
Commit 936d272

Restore SDPA in Gemma2 models for transformers > 4.45 (#976)
* Restore SDPA in Gemma2 models for transformers > 4.45
* Update tests/openvino/test_modeling.py
* Update tests/openvino/test_modeling.py
1 parent 635f939 commit 936d272

File tree

2 files changed: +28 -0 lines changed


optimum/exporters/openvino/model_patcher.py (+20)

@@ -2505,6 +2505,26 @@ def patched_forward(*args, **kwargs):
 
         self.patched_forward = patched_forward
 
+    def __enter__(self):
+        super().__enter__()
+        if is_transformers_version(">=", "4.45.0"):
+            from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_CLASSES
+
+            sdpa_attn = GEMMA2_ATTENTION_CLASSES["sdpa"]
+            eager_attn = GEMMA2_ATTENTION_CLASSES["eager"]
+
+            for layer in self._model.model.layers:
+                if isinstance(layer.self_attn, eager_attn):
+                    layer.self_attn._orig_forward = layer.self_attn.forward
+                    layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        if is_transformers_version(">=", "4.45.0"):
+            for layer in self._model.model.layers:
+                if hasattr(layer.self_attn, "_orig_forward"):
+                    layer.self_attn.forward = layer.self_attn._orig_forward
+
 
 def _decilm_attn_forward(
     self,
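The new __enter__/__exit__ hooks re-bind the SDPA attention's forward onto each eager attention instance for the duration of the export and restore the original method on exit. A minimal, self-contained sketch of that swap-and-restore pattern, with illustrative stand-in classes rather than the real Gemma2 attention classes:

import types


class EagerAttention:
    def forward(self, x):
        return f"eager({x})"


class SdpaAttention(EagerAttention):
    def forward(self, x):
        return f"sdpa({x})"


class SwapToSdpa:
    """Context manager that temporarily re-binds SdpaAttention.forward onto a module."""

    def __init__(self, module):
        self.module = module

    def __enter__(self):
        # Keep the original bound method so it can be restored later.
        self.module._orig_forward = self.module.forward
        # Bind the unbound SDPA forward to this specific instance.
        self.module.forward = types.MethodType(SdpaAttention.forward, self.module)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Put the original forward back, leaving the module unchanged.
        self.module.forward = self.module._orig_forward


attn = EagerAttention()
with SwapToSdpa(attn):
    print(attn.forward("q"))  # sdpa(q)
print(attn.forward("q"))      # eager(q)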

tests/openvino/test_modeling.py (+8)

@@ -863,6 +863,10 @@ def test_compare_to_transformers(self, model_arch):
         if model_arch in self.REMOTE_CODE_MODELS:
             model_kwargs = {"trust_remote_code": True}
 
+        # starting from transformers 4.45.0, gemma2 uses eager attention by default, while the OpenVINO export uses sdpa
+        if model_arch == "gemma2" and is_transformers_version(">=", "4.45.0"):
+            model_kwargs["attn_implementation"] = "sdpa"
+
         ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True, ov_config=F32_CONFIG, **model_kwargs)
         self.assertIsInstance(ov_model.config, PretrainedConfig)
         self.assertTrue(ov_model.use_cache)

@@ -1094,6 +1098,10 @@ def test_beam_search(self, model_arch):
                 "config": AutoConfig.from_pretrained(model_id, trust_remote_code=True),
                 "trust_remote_code": True,
             }
+
+        # starting from transformers 4.45.0, gemma2 uses eager attention by default, while the OpenVINO export uses sdpa
+        if model_arch == "gemma2" and is_transformers_version(">=", "4.45.0"):
+            model_kwargs["attn_implementation"] = "sdpa"
         # Qwen tokenizer does not support padding; chatglm and glm4 testing models produce nan values incompatible with beam search
         if model_arch in ["qwen", "chatglm", "glm4"]:
             return
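Both tests pass attn_implementation="sdpa" so that the reference transformers model uses the same attention backend as the exported OpenVINO model, keeping the outputs comparable. A minimal sketch of how that keyword is used when loading a Gemma2 checkpoint with transformers >= 4.45; the model id below is a placeholder, not the one used in the test suite:

from transformers import AutoModelForCausalLM

model_id = "google/gemma-2-2b"  # placeholder checkpoint id

# Gemma2 defaults to eager attention in transformers >= 4.45;
# requesting SDPA explicitly matches the exported OpenVINO model.
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="sdpa")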
