Skip to content

Commit a43f495

Browse files
committed
fix internvl2 patching for transformers>=4.48
1 parent 7055210 commit a43f495

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

optimum/exporters/openvino/model_patcher.py

+8-8
Original file line number | Diff line number | Diff line change
@@ -3051,22 +3051,22 @@ def patched_forward(self, fn):
30513051

30523052
def __enter__(self):
30533053
if is_torch_version(">=", "2.1.0"):
3054-
if self._model.config.model_type == "qwen2" and self._model.config._attn_implementation != "sdpa":
3055-
if is_transformers_version("<", "4.48"):
3054+
if (
3055+
self._model.config.model_type in ["qwen2", "llama"]
3056+
and self._model.config._attn_implementation != "sdpa"
3057+
):
3058+
self._model.config._orig_attn_implementation = self._model.config._attn_implementation
3059+
self._model.config._attn_implementation = "sdpa"
3060+
if self._model.config.model_type == "qwen2" and is_transformers_version("<", "4.48"):
30563061
from transformers.models.qwen2.modeling_qwen2 import QWEN2_ATTENTION_CLASSES
30573062

30583063
sdpa_attn = QWEN2_ATTENTION_CLASSES["sdpa"]
3059-
self._model.config._orig_attn_implementation = self._model.config._attn_implementation
3060-
self._model.config._attn_implementation = "sdpa"
30613064

30623065
for layer in self._model.model.layers:
30633066
layer.self_attn._orig_forward = layer.self_attn.forward
30643067
layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
30653068

3066-
if self._model.config.model_type == "llama" and self._model.config._attn_implementation != "sdpa":
3067-
self._model.config._orig_attn_implementation = self._model.config._attn_implementation
3068-
self._model.config._attn_implementation = "sdpa"
3069-
if is_transformers_version("<", "4.47"):
3069+
if self._model.config.model_type == "llama" and is_transformers_version("<", "4.47"):
30703070
from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
30713071

30723072
sdpa_attn = LLAMA_ATTENTION_CLASSES["sdpa"]

0 commit comments

Comments
 (0)