Commit 35f6fb6

patch gemma attn functions
1 parent 3e2cf34 commit 35f6fb6

File tree

1 file changed: +22 -8 lines changed

optimum/exporters/openvino/model_patcher.py

+22 -8
@@ -421,9 +421,9 @@ def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, c
             offset = 0
         mask_shape = attention_mask.shape
         mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-        causal_mask[
-            : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-        ] = mask_slice
+        causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
+            mask_slice
+        )
 
     if (
         self.config._attn_implementation == "sdpa"
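
The assignment this hunk reflows folds a user-supplied 4D attention mask into the prepared causal mask: attention_mask.eq(0.0) marks the masked-out positions, and multiplying by the dtype minimum turns them into a large negative additive bias that vanishes under softmax. A minimal standalone sketch of the same conversion, with toy shapes and values chosen purely for illustration:

import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min

# Toy 4D mask: nonzero = attend, 0.0 = masked (batch=1, heads=1, q=2, k=2).
attention_mask = torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])

# Same conversion as in the patched function: masked positions become the
# dtype minimum, attended positions become 0.0.
mask_slice = attention_mask.eq(0.0).to(dtype=dtype) * min_dtype

# Write the converted slice into the matching region of a larger causal mask.
causal_mask = torch.zeros(1, 1, 4, 2, dtype=dtype)
mask_shape = attention_mask.shape
offset = 0
causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = mask_slice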
@@ -2058,9 +2058,9 @@ def _dbrx_update_causal_mask_legacy(
             offset = 0
         mask_shape = attention_mask.shape
         mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-        causal_mask[
-            : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-        ] = mask_slice
+        causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
+            mask_slice
+        )
 
     if (
         self.config._attn_implementation == "sdpa"
@@ -2712,7 +2712,14 @@ def patched_forward(*args, **kwargs):
 
     def __enter__(self):
         super().__enter__()
-        if is_transformers_version(">=", "4.45.0"):
+
+        if is_transformers_version(">=", "4.47.0"):
+            from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_FUNCTION
+
+            GEMMA2_ATTENTION_FUNCTION["original_eager"] = GEMMA2_ATTENTION_FUNCTION["eager"]
+            GEMMA2_ATTENTION_FUNCTION["eager"] = GEMMA2_ATTENTION_FUNCTION["sdpa"]
+
+        elif is_transformers_version(">=", "4.45.0"):
             from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_CLASSES
 
             sdpa_attn = GEMMA2_ATTENTION_CLASSES["sdpa"]
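
As the two import paths in this hunk show, transformers 4.47 moved Gemma2 attention dispatch from per-implementation classes (GEMMA2_ATTENTION_CLASSES) to the GEMMA2_ATTENTION_FUNCTION registry, so the new branch swaps dictionary entries instead of rebinding layer methods: the original "eager" callable is parked under a spare "original_eager" key and "eager" is aliased to the SDPA implementation. A minimal sketch of that registry-swap pattern on a plain dict, with toy stand-in functions (illustrative names only, not the real callables):

# Toy stand-ins for the attention callables held in the registry.
def eager_attention(*args, **kwargs):
    return "eager result"

def sdpa_attention(*args, **kwargs):
    return "sdpa result"

ATTENTION_FUNCTION = {"eager": eager_attention, "sdpa": sdpa_attention}

# Patch: park the original under a spare key, then alias "eager" to the SDPA
# entry so any lookup of "eager" transparently runs the SDPA path.
ATTENTION_FUNCTION["original_eager"] = ATTENTION_FUNCTION["eager"]
ATTENTION_FUNCTION["eager"] = ATTENTION_FUNCTION["sdpa"]

assert ATTENTION_FUNCTION["eager"]() == "sdpa result"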
@@ -2725,7 +2732,14 @@ def __enter__(self):
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
-        if is_transformers_version(">=", "4.45.0"):
+
+        if is_transformers_version(">=", "4.47.0"):
+            from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_FUNCTION
+
+            GEMMA2_ATTENTION_FUNCTION["eager"] = GEMMA2_ATTENTION_FUNCTION["original_eager"]
+            del GEMMA2_ATTENTION_FUNCTION["original_eager"]
+
+        elif is_transformers_version(">=", "4.45.0"):
             for layer in self._model.model.layers:
                 if hasattr(layer.self_attn, "_orig_forward"):
                     layer.self_attn.forward = layer.self_attn._orig_forward
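
Taken together, __enter__ and __exit__ keep the swap strictly scoped: the alias is installed for the duration of the export and undone afterwards, with the spare "original_eager" key deleted so the registry ends up exactly as it started. The same enter/exit pattern, reduced to a self-contained sketch (the PatchedRegistry name is hypothetical, not part of the patched module):

class PatchedRegistry:
    """Alias registry["eager"] to registry["sdpa"] inside a with-block."""

    def __init__(self, registry):
        self.registry = registry

    def __enter__(self):
        self.registry["original_eager"] = self.registry["eager"]
        self.registry["eager"] = self.registry["sdpa"]
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the saved entry and drop the spare key so the registry
        # ends up exactly what it was before entering.
        self.registry["eager"] = self.registry["original_eager"]
        del self.registry["original_eager"]


registry = {"eager": "eager_fn", "sdpa": "sdpa_fn"}
with PatchedRegistry(registry):
    assert registry["eager"] == "sdpa_fn"
assert registry == {"eager": "eager_fn", "sdpa": "sdpa_fn"}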
