
Commit 69fac7b

restore SDPA in gpt neo after 4.45
1 parent 29b2ac9 commit 69fac7b
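The commit title indicates that, with transformers 4.45 and newer, the exported GPT-Neo attention no longer goes through PyTorch's fused SDPA kernel, so the patcher added below re-routes the per-head attention computation through torch.nn.functional.scaled_dot_product_attention during export. The following snippet is an illustrative sketch only (not part of the commit): it contrasts the eager masked-softmax attention that GPT-Neo computes by hand with an equivalent single SDPA call. The tensor shapes and the scale=1.0 choice (GPT-Neo does not scale its attention scores) are assumptions made for the example; the `scale` keyword requires torch >= 2.1, the same minimum version the patcher checks for.

# Illustrative sketch only (not part of the commit): eager attention vs. the fused
# torch.nn.functional.scaled_dot_product_attention call that the patcher restores.
import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 1, 4, 8, 16
query = torch.randn(batch, heads, seq, head_dim)
key = torch.randn(batch, heads, seq, head_dim)
value = torch.randn(batch, heads, seq, head_dim)

# Boolean causal mask: True where a query position may attend to a key position.
causal = torch.tril(torch.ones(seq, seq, dtype=torch.bool))

# Eager path (roughly what GPT-Neo's _attn does by hand): matmul + masked softmax.
# GPT-Neo does not scale the scores by sqrt(head_dim), which scale=1.0 mirrors below.
scores = torch.matmul(query, key.transpose(-1, -2))
scores = scores.masked_fill(~causal, torch.finfo(scores.dtype).min)
eager_out = torch.matmul(torch.softmax(scores, dim=-1), value)

# Fused path: a single scaled_dot_product_attention call with an explicit mask.
sdpa_out = F.scaled_dot_product_attention(query, key, value, attn_mask=causal, scale=1.0)

print(torch.allclose(eager_out, sdpa_out, atol=1e-5))  # expected: True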

File tree: 2 files changed (+88, -0 lines)


optimum/exporters/openvino/model_configs.py (+20 lines)
@@ -30,6 +30,7 @@
     FalconOnnxConfig,
     GemmaOnnxConfig,
     GPTJOnnxConfig,
+    GPTNeoOnnxConfig,
     GPTNeoXOnnxConfig,
     IBertOnnxConfig,
     LlamaOnnxConfig,
@@ -68,6 +69,7 @@
     FluxTransfromerModelPatcher,
     Gemma2ModelPatcher,
     GptJModelPatcher,
+    GptNeoModelPatcher,
     GptNeoxJapaneseModelPatcher,
     GptNeoxModelPatcher,
     IBertModelPatcher,
@@ -790,6 +792,24 @@ def patch_model_for_export(
         return GptNeoxJapaneseModelPatcher(self, model, model_kwargs=model_kwargs)
 
 
+@register_in_tasks_manager(
+    "gpt-neo",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text-classification",
+    ],
+    library_name="transformers",
+)
+class GPTNeoOpenVINOConfig(GPTNeoOnnxConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return GptNeoModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
 @register_in_tasks_manager(
     "gptj",
     *[
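With the "gpt-neo" entry registered above, an OpenVINO export of a GPT-Neo checkpoint should resolve to GPTNeoOpenVINOConfig through the tasks manager and therefore wrap the conversion in GptNeoModelPatcher. A minimal usage sketch (not part of the commit), assuming optimum-intel is installed and using EleutherAI/gpt-neo-125m purely as an example model id:

# Usage sketch (not part of the commit). Assumes optimum-intel with its OpenVINO
# dependencies is installed; "EleutherAI/gpt-neo-125m" is only an example checkpoint.
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

model_id = "EleutherAI/gpt-neo-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# export=True converts the PyTorch checkpoint to OpenVINO IR on the fly; during that
# conversion the "gpt-neo" config registered above applies GptNeoModelPatcher.
model = OVModelForCausalLM.from_pretrained(model_id, export=True)

inputs = tokenizer("OpenVINO export of GPT-Neo", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

The same conversion can also be run from the command line with `optimum-cli export openvino --model EleutherAI/gpt-neo-125m <output_dir>`.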

optimum/exporters/openvino/model_patcher.py (+68 lines)
@@ -2654,6 +2654,74 @@ def __exit__(self, exc_type, exc_value, traceback):
         unpatch_update_causal_mask(self._model, "gpt_neox_japanese")
 
 
+def _gpt_neo_sdpa_attn(self, query, key, value, attention_mask=None, head_mask=None):
+    # Keep the attention weights computation in fp32 to avoid overflow issues
+    query = query.to(torch.float32)
+    key = key.to(torch.float32)
+
+    # Apply sliding window masking for local attention layers
+    query_length, key_length = query.size(-2), key.size(-2)
+    causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+    # different from original for prevent overflow, apply to mask instead of directly to weights
+    mask_value = torch.finfo(torch.float16).min
+    # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+    # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+    mask_value = torch.tensor(mask_value, dtype=query.dtype).to(query.device)
+    if attention_mask is None:
+        attention_mask = torch.ones_like(causal_mask)
+    attention_mask = torch.where(causal_mask, attention_mask[:, :, :, : key.shape[-2]], mask_value)
+
+    # Mask heads if we want to
+    if head_mask is not None:
+        attention_mask = attention_mask * head_mask
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
+
+    return attn_output, None
+
+
+def _gpt_neo_attn_forward(
+    self,
+    hidden_states,
+    attention_mask=None,
+    layer_past=None,
+    head_mask=None,
+    use_cache=False,
+    output_attentions=False,
+    cache_position=None,
+):
+    if output_attentions:
+        self._attn = self._orig_attn
+
+    return self._orig_forward(
+        hidden_states,
+        attention_mask=attention_mask,
+        layer_past=layer_past,
+        head_mask=head_mask,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        cache_position=cache_position,
+    )
+
+
+class GptNeoModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        if is_transformers_version(">=", "4.45.0") and is_torch_version(">=", "2.1.0"):
+            for layer in self._model.transformer.h:
+                self_attn = layer.attn.attention
+                self_attn._orig_attn = self_attn._attn
+                self_attn._attn = types.MethodType(_gpt_neo_sdpa_attn, self_attn)
+                self_attn._orig_forward = types.MethodType(_gpt_neo_attn_forward, self_attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for layer in self._model.transformer.h:
+            if hasattr(layer.attn.attention, "_orig_forward"):
+                layer.attn.attention.forward = layer.attn.attention._orig_forward
+                layer.attn.attention._attn = layer.attn.attention._orig_attn
+
+
 class Gemma2ModelPatcher(LlamaModelPatcher):
     def __init__(
         self,
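The patcher itself is plain method monkey-patching: on __enter__ it stores each attention module's original _attn and binds the SDPA-based replacement with types.MethodType, and on __exit__ it restores the originals so the model is left untouched after export. Below is a self-contained toy sketch of that save/patch/restore pattern; all names in it are illustrative and not taken from the commit.

# Toy sketch (not part of the commit) of the save/patch/restore pattern used by
# GptNeoModelPatcher. Only the mechanism mirrors the real code; names are made up.
import types


class Attention:
    def _attn(self, x):
        return x + 1  # stand-in for the eager attention computation


def _patched_attn(self, x):
    return x * 2  # stand-in for the SDPA-based replacement


class Patcher:
    def __init__(self, module):
        self.module = module

    def __enter__(self):
        # Keep a reference to the original bound method so it can be restored later.
        self.module._orig_attn = self.module._attn
        # Rebind the replacement as a method of this specific instance.
        self.module._attn = types.MethodType(_patched_attn, self.module)
        return self.module

    def __exit__(self, exc_type, exc_value, traceback):
        self.module._attn = self.module._orig_attn


attn = Attention()
with Patcher(attn) as patched:
    print(patched._attn(3))  # 6 -> patched implementation in effect
print(attn._attn(3))  # 4 -> original restored on exit

Binding with types.MethodType (rather than assigning a bare function) keeps `self` bound to the specific attention module instance, which is why the replacement in the commit can read instance attributes such as self.bias.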

0 commit comments
