
Commit 35c47a2

force attn model
1 parent 45133cb

2 files changed: +7 -41 lines

2 files changed

+7
-41
lines changed

optimum/exporters/openvino/__main__.py

+1 -1
@@ -49,7 +49,7 @@
 )
 
 
-FORCE_ATTN_MODEL_CLASSES = {"phi3-v": "eager"}
+FORCE_ATTN_MODEL_CLASSES = {"phi3-v": "eager", "gemma2": "sdpa"}
 
 if TYPE_CHECKING:
     from optimum.intel.openvino.configuration import OVConfig
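
The new entry forces SDPA attention for gemma2 when the model is loaded for export: FORCE_ATTN_MODEL_CLASSES maps a config model_type to the attn_implementation to force. A minimal sketch of how such a table can be consumed, assuming a hypothetical load_for_export helper (the actual optimum-intel loading code differs):

# Hypothetical sketch, not the optimum-intel implementation.
from transformers import AutoConfig, AutoModelForCausalLM

FORCE_ATTN_MODEL_CLASSES = {"phi3-v": "eager", "gemma2": "sdpa"}


def load_for_export(model_id: str):
    config = AutoConfig.from_pretrained(model_id)
    # Force the attention implementation for model types known to need it;
    # None lets transformers pick its own default.
    forced = FORCE_ATTN_MODEL_CLASSES.get(config.model_type)
    return AutoModelForCausalLM.from_pretrained(
        model_id, config=config, attn_implementation=forced
    )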

optimum/exporters/openvino/model_patcher.py

+6 -40
@@ -421,9 +421,9 @@ def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, c
             offset = 0
             mask_shape = attention_mask.shape
             mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-            causal_mask[
-                : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-            ] = mask_slice
+            causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
+                mask_slice
+            )
 
     if (
         self.config._attn_implementation == "sdpa"
@@ -2058,9 +2058,9 @@ def _dbrx_update_causal_mask_legacy(
             offset = 0
             mask_shape = attention_mask.shape
             mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-            causal_mask[
-                : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-            ] = mask_slice
+            causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
+                mask_slice
+            )
 
     if (
         self.config._attn_implementation == "sdpa"
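
Both hunks above only reflow the same assignment to Black's style; behavior is unchanged. For readers unfamiliar with the statement, a standalone illustration under assumed tensor shapes (the buffer and mask below are fabricated for the example):

# Standalone illustration with made-up shapes; mirrors the reflowed statement.
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min  # most negative representable value

causal_mask = torch.zeros(2, 1, 8, 8, dtype=dtype)    # pre-allocated buffer
attention_mask = torch.ones(2, 1, 4, 8, dtype=dtype)  # 4D mask, 0.0 = masked
attention_mask[..., 6:] = 0.0                         # mask the last key positions

offset = 0
mask_shape = attention_mask.shape
# Re-encode masked (0.0) positions as min_dtype and write them into a slice
# of the buffer; only the formatting of this assignment changed in the diff.
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
    mask_slice
)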
@@ -2710,40 +2710,6 @@ def patched_forward(*args, **kwargs):
 
         self.patched_forward = patched_forward
 
-    def __enter__(self):
-        super().__enter__()
-
-        if is_transformers_version(">=", "4.47.0"):
-            from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_FUNCTION
-
-            GEMMA2_ATTENTION_FUNCTION["original_eager"] = GEMMA2_ATTENTION_FUNCTION["eager"]
-            GEMMA2_ATTENTION_FUNCTION["eager"] = GEMMA2_ATTENTION_FUNCTION["sdpa"]
-
-        elif is_transformers_version(">=", "4.45.0"):
-            from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_CLASSES
-
-            sdpa_attn = GEMMA2_ATTENTION_CLASSES["sdpa"]
-            eager_attn = GEMMA2_ATTENTION_CLASSES["eager"]
-
-            for layer in self._model.model.layers:
-                if isinstance(layer.self_attn, eager_attn):
-                    layer.self_attn._orig_forward = layer.self_attn.forward
-                    layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        super().__exit__(exc_type, exc_value, traceback)
-
-        if is_transformers_version(">=", "4.47.0"):
-            from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_FUNCTION
-
-            GEMMA2_ATTENTION_FUNCTION["eager"] = GEMMA2_ATTENTION_FUNCTION["original_eager"]
-            del GEMMA2_ATTENTION_FUNCTION["original_eager"]
-
-        elif is_transformers_version(">=", "4.45.0"):
-            for layer in self._model.model.layers:
-                if hasattr(layer.self_attn, "_orig_forward"):
-                    layer.self_attn.forward = layer.self_attn._orig_forward
-
 
 def _decilm_attn_forward(
     self,
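
The deleted __enter__/__exit__ pair had rerouted Gemma2's eager attention to SDPA only while the patcher was active, either by rewriting the GEMMA2_ATTENTION_FUNCTION registry (transformers >= 4.47.0) or by rebinding each layer's self_attn.forward (>= 4.45.0); forcing "sdpa" through FORCE_ATTN_MODEL_CLASSES at load time makes that runtime swap redundant. A minimal self-contained sketch of the same patch-and-restore idiom (RegistryPatcher and the toy registry are illustrative, not the removed patcher class):

# Illustrative patch-and-restore context manager; not the removed
# optimum-intel code, just the same idiom on a toy registry.
class RegistryPatcher:
    def __init__(self, registry: dict):
        self._registry = registry

    def __enter__(self):
        # Back up the original "eager" entry, then alias it to "sdpa".
        self._registry["original_eager"] = self._registry["eager"]
        self._registry["eager"] = self._registry["sdpa"]
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the backup and drop the temporary key.
        self._registry["eager"] = self._registry.pop("original_eager")


attention_functions = {"eager": "eager_impl", "sdpa": "sdpa_impl"}
with RegistryPatcher(attention_functions):
    assert attention_functions["eager"] == "sdpa_impl"  # patched inside the block
assert attention_functions["eager"] == "eager_impl"     # restored on exit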
