restore SDPA in gpt neo after 4.45 #1092

Merged 3 commits on Dec 30, 2024
20 changes: 20 additions & 0 deletions optimum/exporters/openvino/model_configs.py
@@ -30,6 +30,7 @@
    FalconOnnxConfig,
    GemmaOnnxConfig,
    GPTJOnnxConfig,
    GPTNeoOnnxConfig,
    GPTNeoXOnnxConfig,
    IBertOnnxConfig,
    LlamaOnnxConfig,
@@ -68,6 +69,7 @@
    FluxTransfromerModelPatcher,
    Gemma2ModelPatcher,
    GptJModelPatcher,
    GptNeoModelPatcher,
    GptNeoxJapaneseModelPatcher,
    GptNeoxModelPatcher,
    IBertModelPatcher,
@@ -790,6 +792,24 @@ def patch_model_for_export(
        return GptNeoxJapaneseModelPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager(
    "gpt-neo",
    *[
        "feature-extraction",
        "feature-extraction-with-past",
        "text-generation",
        "text-generation-with-past",
        "text-classification",
    ],
    library_name="transformers",
)
class GPTNeoOpenVINOConfig(GPTNeoOnnxConfig):
    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        return GptNeoModelPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager(
"gptj",
*[
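
For reference, a minimal usage sketch of the path this registration enables: exporting a GPT-Neo checkpoint to OpenVINO through optimum-intel picks up GPTNeoOpenVINOConfig and its patcher during conversion. The checkpoint name and generation settings below are illustrative, not part of this PR.

# Illustrative export sketch; "EleutherAI/gpt-neo-125m" is only an example checkpoint.
from transformers import AutoTokenizer

from optimum.intel import OVModelForCausalLM

model_id = "EleutherAI/gpt-neo-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# export=True converts the PyTorch model to OpenVINO IR, routing through the
# registered gpt-neo export config and its model patcher.
ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True)

inputs = tokenizer("Today is a nice day and", return_tensors="pt")
outputs = ov_model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
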
90 changes: 90 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
@@ -2681,6 +2681,96 @@ def __exit__(self, exc_type, exc_value, traceback):
        unpatch_update_causal_mask(self._model, "gpt_neox_japanese")


def _gpt_neo_attn_forward(
    self,
    hidden_states,
    attention_mask=None,
    layer_past=None,
    head_mask=None,
    use_cache=False,
    output_attentions=False,
    cache_position=None,
):
    if output_attentions:
        self._attn = self._orig_attn

    return self._orig_forward(
        hidden_states,
        attention_mask=attention_mask,
        layer_past=layer_past,
        head_mask=head_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
        cache_position=cache_position,
    )


# Adopted from https://github.com/huggingface/optimum/blob/main/optimum/bettertransformer/models/attention.py#L185
def _gpt_neo_attn_sdpa(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
):
    batch_size = query.shape[0]

    mask_value = torch.finfo(torch.float16).min
    mask_value = torch.full([], mask_value, dtype=value.dtype)

    dropout_p = float(self.config.attention_dropout) if self.training else 0.0
    if (batch_size == 1 or self.training) and self.attention_type == "global":
        if query.shape[2] > 1:
            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
            )
        else:
            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False, scale=1.0
            )
    else:
        query_length, key_length = query.size(-2), key.size(-2)

        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]

        causal_mask = torch.where(causal_mask, 0, mask_value)
        if batch_size > 1:
            # torch.Tensor.expand does no memory copy
            causal_mask = causal_mask.expand(batch_size, -1, -1, -1)

        if attention_mask is not None:
            attention_mask = causal_mask + attention_mask

        sdpa_result = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False, scale=1.0
        )

    return sdpa_result, None


class GptNeoModelPatcher(DecoderModelPatcher):
    def __enter__(self):
        super().__enter__()
        if is_transformers_version(">=", "4.45.0") and is_torch_version(">=", "2.1.0"):
            self._model.config._orig_attn_implementation = self._model.config._attn_implementation
            self._model.config._attn_implementation = "sdpa"
            for layer in self._model.transformer.h:
                self_attn = layer.attn.attention
                self_attn._orig_attn = self_attn._attn
                self_attn._attn = types.MethodType(_gpt_neo_attn_sdpa, self_attn)
                self_attn._orig_forward = self_attn.forward
                self_attn.forward = types.MethodType(_gpt_neo_attn_forward, self_attn)

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        if hasattr(self._model.config, "_orig_attn_implementation"):
            self._model.config._attn_implementation = self._model.config._orig_attn_implementation
            for layer in self._model.transformer.h:
                layer.attn.attention.forward = layer.attn.attention._orig_forward
                layer.attn.attention._attn = layer.attn.attention._orig_attn


class Gemma2ModelPatcher(LlamaModelPatcher):
    def __init__(
        self,
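
For intuition about the patch above, a small self-contained sketch showing that torch.nn.functional.scaled_dot_product_attention with is_causal=True matches an eager masked-softmax reference. Shapes are arbitrary, and the sketch uses SDPA's default 1/sqrt(d) scaling rather than the scale=1.0 that the GPT-Neo patch passes on some paths.

# Standalone check: SDPA with is_causal=True vs. an explicit causal mask.
import torch
import torch.nn.functional as F

batch, heads, seq, dim = 2, 4, 8, 16
q = torch.randn(batch, heads, seq, dim)
k = torch.randn(batch, heads, seq, dim)
v = torch.randn(batch, heads, seq, dim)

sdpa_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Eager reference: scaled scores, lower-triangular mask, softmax, weighted sum.
mask = torch.tril(torch.ones(seq, seq, dtype=torch.bool))
scores = (q @ k.transpose(-2, -1)) / dim**0.5
scores = scores.masked_fill(~mask, torch.finfo(scores.dtype).min)
eager_out = torch.softmax(scores, dim=-1) @ v

print(torch.allclose(sdpa_out, eager_out, atol=1e-5))  # expected: True
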
10 changes: 9 additions & 1 deletion tests/openvino/test_modeling.py
@@ -1296,7 +1296,15 @@ def test_beam_search(self, model_arch):
            transformers_model._supports_cache_class = True
            from transformers.cache_utils import DynamicCache
        tokenizer.pad_token_id = tokenizer.eos_token_id
        tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)
        tokenization_args = {}
        if is_transformers_version(">=", "4.45") and model_arch == "gpt_neo":
            tokenization_args["padding_side"] = "left"
        tokens = tokenizer(
            ["Today is a nice day and I am longer", "This is me"],
            return_tensors="pt",
            padding=True,
            **tokenization_args,
        )
        ov_model_stateful.generation_config.eos_token_id = None
        ov_model_stateless.generation_config.eos_token_id = None
        transformers_model.generation_config.eos_token_id = None
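
For context on the test change, a short sketch of why left padding is used for batched generation with a decoder-only model such as GPT-Neo: with right padding the shorter prompt ends in pad tokens, so generated tokens would be appended after the padding instead of after the prompt. Passing padding_side at call time requires a recent transformers release, which is why the test gates it on >= 4.45; the tokenizer checkpoint below is illustrative.

# Illustrative comparison of right vs. left padding for a batched prompt.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m")
tokenizer.pad_token_id = tokenizer.eos_token_id

batch = ["Today is a nice day and I am longer", "This is me"]

right = tokenizer(batch, return_tensors="pt", padding=True)
left = tokenizer(batch, return_tensors="pt", padding=True, padding_side="left")

# Right padding: the short prompt ends in pad tokens, so generate() would
# continue after the padding. Left padding: every prompt ends at the last
# position, so new tokens follow the real text directly.
print(right["input_ids"][1])
print(left["input_ids"][1])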