Commit 4bf5ffc

couple initial fixes
1 parent 3ad88da commit 4bf5ffc

File tree

1 file changed: +15 -11 lines changed


optimum/exporters/openvino/model_patcher.py

@@ -718,14 +718,15 @@ def _mistral_update_causal_mask(
 class MistralModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
-        if is_transformers_version(">=", "4.42.0"):
+        if is_transformers_version(">=", "4.42.0") and is_transformers_version("<", "4.48.0"):
             # apply fix https://github.com/huggingface/transformers/commit/57d7594a79a9f5d835abf2d4d384db0e4818e548
             self._model.model._orig_update_causal_mask = self._model.model._update_causal_mask
             self._model.model._update_causal_mask = types.MethodType(_mistral_update_causal_mask, self._model.model)

         else:
             for layer in self._model.model.layers:
-                _reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)
+                if hasattr(layer.self_attn, "rotary_emb"):
+                    _reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)

     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
@@ -734,7 +735,7 @@ def __exit__(self, exc_type, exc_value, traceback):
             self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask

         for layer in self._model.model.layers:
-            if hasattr(layer.self_attn.rotary_emb, "_orig_forward"):
+            if hasattr(layer.self_attn, "rotary_emb") and hasattr(layer.self_attn.rotary_emb, "_orig_forward"):
                 layer.self_attn.rotary_emb.forward = layer.self_attn.rotary_emb._orig_forward

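Note: both bounds in the new check go through the same version-comparison helper. A hypothetical sketch of such a helper, for readers unfamiliar with it (the real is_transformers_version in optimum may be implemented differently):

import operator

import transformers
from packaging import version

_OPS = {"<": operator.lt, "<=": operator.le, "==": operator.eq, ">=": operator.ge, ">": operator.gt}

def is_transformers_version(op: str, ref: str) -> bool:
    # Compare the installed transformers version against a reference string,
    # e.g. is_transformers_version(">=", "4.42.0") and is_transformers_version("<", "4.48.0").
    return _OPS[op](version.parse(transformers.__version__), version.parse(ref))
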
@@ -2493,7 +2494,9 @@ class UpdateCausalMaskModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
         patch_update_causal_mask(self._model, "4.42.0")
-        if hasattr(self._model.model.layers[0].self_attn.rotary_emb, "_set_cos_sin_cache"):
+        if hasattr(self._model.model.layers[0].self_attn, "rotary_emb") and hasattr(
+            self._model.model.layers[0].self_attn.rotary_emb, "_set_cos_sin_cache"
+        ):
             for layer in self._model.model.layers:
                 _reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)

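Note: the added hasattr checks reflect the assumption that newer transformers releases (4.48+) move rotary_emb from each attention layer onto the model itself, so per-layer access can no longer be taken for granted. A minimal sketch of the defensive lookup pattern (helper name is illustrative, not part of the patch):

def restore_rotary_forward(model):
    # Undo a patched rotary-embedding forward only where the per-layer
    # attribute still exists; on newer transformers the layer may not carry it.
    for layer in model.model.layers:
        rotary = getattr(layer.self_attn, "rotary_emb", None)
        if rotary is not None and hasattr(rotary, "_orig_forward"):
            rotary.forward = rotary._orig_forward
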
@@ -3045,15 +3048,16 @@ def patched_forward(self, fn):
     def __enter__(self):
         if is_torch_version(">=", "2.1.0"):
             if self._model.config.model_type == "qwen2" and self._model.config._attn_implementation != "sdpa":
-                from transformers.models.qwen2.modeling_qwen2 import QWEN2_ATTENTION_CLASSES
+                if is_transformers_version("<", "4.48"):
+                    from transformers.models.qwen2.modeling_qwen2 import QWEN2_ATTENTION_CLASSES

-                sdpa_attn = QWEN2_ATTENTION_CLASSES["sdpa"]
-                self._model.config._orig_attn_implementation = self._model.config._attn_implementation
-                self._model.config._attn_implementation = "sdpa"
+                    sdpa_attn = QWEN2_ATTENTION_CLASSES["sdpa"]
+                    self._model.config._orig_attn_implementation = self._model.config._attn_implementation
+                    self._model.config._attn_implementation = "sdpa"

-                for layer in self._model.model.layers:
-                    layer.self_attn._orig_forward = layer.self_attn.forward
-                    layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
+                    for layer in self._model.model.layers:
+                        layer.self_attn._orig_forward = layer.self_attn.forward
+                        layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)

             if self._model.config.model_type == "llama" and self._model.config._attn_implementation != "sdpa":
                 self._model.config._orig_attn_implementation = self._model.config._attn_implementation

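Note: the Qwen2 branch swaps each layer's forward via types.MethodType, which binds a plain function to an existing instance, so the module keeps its weights while its forward is replaced and the original is stashed on _orig_forward for later restoration. A small standalone illustration of that mechanism, with toy classes rather than the real attention modules:

import types

class Attention:
    def forward(self, x):
        return x + 1

def sdpa_forward(self, x):
    # stand-in for QWEN2_ATTENTION_CLASSES["sdpa"].forward
    return x + 2

attn = Attention()
attn._orig_forward = attn.forward                    # keep the bound original
attn.forward = types.MethodType(sdpa_forward, attn)  # rebind: attn.forward(3) == 5
attn.forward = attn._orig_forward                    # restore: attn.forward(3) == 4
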
0 commit comments
