
Commit 1733791

restore SDPA in gpt neo after 4.45 (#1092)
* restore SDPA in gpt neo after 4.45
* fix accuracy
* left padding
1 parent 014a840 commit 1733791
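
For context, a minimal usage sketch of the path this commit touches (not part of the change itself): exporting a GPT-Neo checkpoint through the OpenVINO exporter, where GptNeoModelPatcher re-enables SDPA attention during tracing, and then generating from a left-padded batch as in the updated test. The checkpoint id and generation settings are illustrative assumptions.

# Illustration only, not part of this commit; assumes optimum-intel with the OpenVINO extra installed.
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

model_id = "EleutherAI/gpt-neo-125m"  # example checkpoint (assumption); any GPT-Neo model id follows the same path
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # batched decoder-only generation expects left padding (see the test change below)

# export=True runs the OpenVINO export, during which GptNeoModelPatcher is applied while tracing
model = OVModelForCausalLM.from_pretrained(model_id, export=True)

inputs = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))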

File tree

3 files changed: +119 -1 lines changed


optimum/exporters/openvino/model_configs.py

+20
@@ -30,6 +30,7 @@
     FalconOnnxConfig,
     GemmaOnnxConfig,
     GPTJOnnxConfig,
+    GPTNeoOnnxConfig,
     GPTNeoXOnnxConfig,
     IBertOnnxConfig,
     LlamaOnnxConfig,
@@ -68,6 +69,7 @@
     FluxTransfromerModelPatcher,
     Gemma2ModelPatcher,
     GptJModelPatcher,
+    GptNeoModelPatcher,
     GptNeoxJapaneseModelPatcher,
     GptNeoxModelPatcher,
     IBertModelPatcher,
@@ -790,6 +792,24 @@ def patch_model_for_export(
         return GptNeoxJapaneseModelPatcher(self, model, model_kwargs=model_kwargs)


+@register_in_tasks_manager(
+    "gpt-neo",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text-classification",
+    ],
+    library_name="transformers",
+)
+class GPTNeoOpenVINOConfig(GPTNeoOnnxConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return GptNeoModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
 @register_in_tasks_manager(
     "gptj",
     *[
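
As a rough illustration of how the new registration is consumed (not part of the diff; the lookup call, checkpoint id, and dummy inputs are assumptions about the surrounding optimum machinery): the tasks manager resolves "gpt-neo" plus a task to GPTNeoOpenVINOConfig, and the exporter enters the patcher returned by patch_model_for_export around tracing.

# Sketch only, simplified from the real export pipeline in optimum/exporters/openvino.
import torch
from transformers import AutoModelForCausalLM
from optimum.exporters.tasks import TasksManager
import optimum.exporters.openvino.model_configs  # noqa: F401  -- runs the register_in_tasks_manager decorators above

model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125m")  # example checkpoint (assumption)

config_constructor = TasksManager.get_exporter_config_constructor(
    exporter="openvino", model=model, task="text-generation", library_name="transformers"
)
export_config = config_constructor(model.config)  # -> GPTNeoOpenVINOConfig

# patch_model_for_export returns GptNeoModelPatcher; inside the context GPT-Neo attention runs the SDPA path
with export_config.patch_model_for_export(model):
    dummy = torch.ones((2, 8), dtype=torch.long)
    model(input_ids=dummy, attention_mask=torch.ones_like(dummy))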

optimum/exporters/openvino/model_patcher.py

+90
@@ -2681,6 +2681,96 @@ def __exit__(self, exc_type, exc_value, traceback):
         unpatch_update_causal_mask(self._model, "gpt_neox_japanese")


+def _gpt_neo_attn_forward(
+    self,
+    hidden_states,
+    attention_mask=None,
+    layer_past=None,
+    head_mask=None,
+    use_cache=False,
+    output_attentions=False,
+    cache_position=None,
+):
+    if output_attentions:
+        self._attn = self._orig_attn
+
+    return self._orig_forward(
+        hidden_states,
+        attention_mask=attention_mask,
+        layer_past=layer_past,
+        head_mask=head_mask,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        cache_position=cache_position,
+    )
+
+
+# Adopted from https://github.com/huggingface/optimum/blob/main/optimum/bettertransformer/models/attention.py#L185
+def _gpt_neo_attn_sdpa(
+    self,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    head_mask: Optional[torch.Tensor] = None,
+):
+    batch_size = query.shape[0]
+
+    mask_value = torch.finfo(torch.float16).min
+    mask_value = torch.full([], mask_value, dtype=value.dtype)
+
+    dropout_p = float(self.config.attention_dropout) if self.training else 0.0
+    if (batch_size == 1 or self.training) and self.attention_type == "global":
+        if query.shape[2] > 1:
+            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
+            )
+        else:
+            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False, scale=1.0
+            )
+    else:
+        query_length, key_length = query.size(-2), key.size(-2)
+
+        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+
+        causal_mask = torch.where(causal_mask, 0, mask_value)
+        if batch_size > 1:
+            # torch.Tensor.expand does no memory copy
+            causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
+
+        if attention_mask is not None:
+            attention_mask = causal_mask + attention_mask
+
+        sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False, scale=1.0
+        )
+
+    return sdpa_result, None
+
+
+class GptNeoModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        if is_transformers_version(">=", "4.45.0") and is_torch_version(">=", "2.1.0"):
+            self._model.config._orig_attn_implementation = self._model.config._attn_implementation
+            self._model.config._attn_implementation = "sdpa"
+            for layer in self._model.transformer.h:
+                self_attn = layer.attn.attention
+                self_attn._orig_attn = self_attn._attn
+                self_attn._attn = types.MethodType(_gpt_neo_attn_sdpa, self_attn)
+                self_attn._orig_forward = self_attn.forward
+                self_attn.forward = types.MethodType(_gpt_neo_attn_forward, self_attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        if hasattr(self._model.config, "_orig_attn_implementation"):
+            self._model.config._attn_implementation = self._model.config._orig_attn_implementation
+            for layer in self._model.transformer.h:
+                layer.attn.attention.forward = layer.attn.attention._orig_forward
+                layer.attn.attention._attn = layer.attn.attention._orig_attn
+
+
 class Gemma2ModelPatcher(LlamaModelPatcher):
     def __init__(
         self,
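
The reason the SDPA calls above pass scale=1.0 and build the additive mask from self.bias with the float16 minimum is that GPT-Neo's eager attention does not scale QK^T by 1/sqrt(head_dim). A small standalone check of that equivalence (illustrative, not part of the commit; requires torch >= 2.1 for the scale argument):

# Sanity check: SDPA with scale=1.0 and an additive causal mask matches GPT-Neo-style unscaled eager attention.
import torch

torch.manual_seed(0)
batch, heads, seq, head_dim = 2, 4, 8, 16
query = torch.randn(batch, heads, seq, head_dim)
key = torch.randn(batch, heads, seq, head_dim)
value = torch.randn(batch, heads, seq, head_dim)

causal = torch.tril(torch.ones(seq, seq, dtype=torch.bool)).view(1, 1, seq, seq)
mask_value = torch.full([], torch.finfo(torch.float16).min, dtype=query.dtype)

# Eager path, mirroring GPTNeoSelfAttention._attn: unscaled QK^T, masked positions set to mask_value
attn_weights = torch.matmul(query, key.transpose(-1, -2))
attn_weights = torch.where(causal, attn_weights, mask_value)
eager = torch.matmul(torch.softmax(attn_weights, dim=-1), value)

# SDPA path as in _gpt_neo_attn_sdpa: additive mask (0 where allowed, mask_value where masked), scale=1.0
additive_mask = torch.where(causal, torch.zeros([], dtype=query.dtype), mask_value).expand(batch, -1, -1, -1)
sdpa = torch.nn.functional.scaled_dot_product_attention(
    query, key, value, attn_mask=additive_mask, dropout_p=0.0, is_causal=False, scale=1.0
)

print(torch.allclose(eager, sdpa, atol=1e-4))  # expected: True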

tests/openvino/test_modeling.py

+9 -1
@@ -1296,7 +1296,15 @@ def test_beam_search(self, model_arch):
         transformers_model._supports_cache_class = True
         from transformers.cache_utils import DynamicCache
         tokenizer.pad_token_id = tokenizer.eos_token_id
-        tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)
+        tokenization_args = {}
+        if is_transformers_version(">=", "4.45") and model_arch == "gpt_neo":
+            tokenization_args["padding_side"] = "left"
+        tokens = tokenizer(
+            ["Today is a nice day and I am longer", "This is me"],
+            return_tensors="pt",
+            padding=True,
+            **tokenization_args,
+        )
         ov_model_stateful.generation_config.eos_token_id = None
         ov_model_stateless.generation_config.eos_token_id = None
         transformers_model.generation_config.eos_token_id = None
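
The padding change matters because test_beam_search batches prompts of different lengths: with right padding, the shorter prompt ends in pad tokens and a decoder-only model would continue generating from padding. A small illustration (not part of the commit; the checkpoint id is an example):

# Compare right vs. left padding for a batched decoder-only prompt.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m")  # example checkpoint (assumption)
tokenizer.pad_token = tokenizer.eos_token
texts = ["Today is a nice day and I am longer", "This is me"]

tokenizer.padding_side = "right"
right = tokenizer(texts, return_tensors="pt", padding=True)
tokenizer.padding_side = "left"
left = tokenizer(texts, return_tensors="pt", padding=True)

# With right padding the last positions of the short prompt are pad tokens; with left padding the
# final position is always a real token, which is what generation continues from.
print(right["input_ids"][1])
print(left["input_ids"][1])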
