
Commit 4b29d9e

disable flash_attn during export internvl2
1 parent: 190ae87
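For context: InternVL2's vision encoder can be configured with use_flash_attn=True, and flash-attention kernels do not follow the plain attention code path that export tracing expects, so this commit switches the affected layers to regular attention for the duration of the export and restores the original flags afterwards. A minimal standalone sketch of that save/flip/restore pattern (flash_attn_disabled is a hypothetical helper written for illustration, not part of this patch):

from contextlib import contextmanager

@contextmanager
def flash_attn_disabled(vision_model):
    # Hypothetical helper mirroring the patch below: remember each layer's
    # flag, force the non-flash attention path while tracing, then restore.
    layers = vision_model.encoder.layers
    saved = [layer.attn.use_flash_attn for layer in layers]
    for layer in layers:
        layer.attn.use_flash_attn = False
    try:
        yield vision_model
    finally:
        for layer, flag in zip(layers, saved):
            layer.attn.use_flash_attn = flag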

File tree

2 files changed: +88 −1 lines


optimum/exporters/openvino/model_configs.py (+6 −1)
@@ -84,6 +84,7 @@
     InputEmbeddingPatcher,
     InternLM2Patcher,
     InternLMModelPatcher,
+    InternVL2ChatLangModelPatcher,
     InternVLChatImageEmbeddingModelPatcher,
     JaisModelPatcher,
     LlamaModelPatcher,
@@ -1638,7 +1639,11 @@ def with_behavior(
         if behavior == InternVLChatConfigBehavior.LANGUAGE:
             model_type = self._orig_config.llm_config.model_type
             return get_vlm_text_generation_config(
-                model_type, self._orig_config.llm_config, self.int_dtype, self.float_dtype
+                model_type,
+                self._orig_config.llm_config,
+                self.int_dtype,
+                self.float_dtype,
+                InternVL2ChatLangModelPatcher,
             )

         if behavior == InternVLChatConfigBehavior.VISION_EMBEDDINGS:
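The change above passes InternVL2ChatLangModelPatcher as an extra positional argument so the language-model export config wraps the model in that patcher instead of the default decoder patcher. A rough sketch of what such a hook plausibly looks like (make_config and base_cls are illustrative names; the real get_vlm_text_generation_config in optimum-intel may wire this differently):

# Illustrative sketch only: a config whose patch_model_for_export
# instantiates a caller-supplied patcher class.
def make_config(base_cls, model_patcher_cls):
    class _TextGenConfig(base_cls):
        def patch_model_for_export(self, model, model_kwargs=None):
            # Wrap the model in the supplied patcher instead of the default.
            return model_patcher_cls(self, model, model_kwargs=model_kwargs or {})

    return _TextGenConfig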

optimum/exporters/openvino/model_patcher.py (+82 −0)
@@ -21,9 +21,11 @@

 import torch
 import torch.nn.functional as F
+from transformers import PreTrainedModel, TFPreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
 from transformers.utils import is_tf_available

+from optimum.exporters.onnx.base import OnnxConfig
 from optimum.exporters.onnx.model_patcher import DecoderModelPatcher, ModelPatcher, override_arguments
 from optimum.intel.utils.import_utils import (
     _openvino_version,
@@ -2987,11 +2989,91 @@ def __init__(
         model.__orig_forward = model.forward
         model.forward = model.extract_feature

+        if model.vision_model.encoder.layers[0].attn.use_flash_attn:
+            for layer in model.vision_model.encoder.layers:
+                layer.attn._orig_use_flash_attn = layer.attn.use_flash_attn
+                layer.attn.use_flash_attn = False
+
         super().__init__(config, model, model_kwargs)

     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         self._model.forward = self._model.__orig_forward
+        if hasattr(self._model.vision_model.encoder.layers[0].attn, "_orig_use_flash_attn"):
+            for layer in self._model.vision_model.encoder.layers:
+                layer.attn.use_flash_attn = layer.attn._orig_use_flash_attn
+
+
+class InternVL2ChatLangModelPatcher(DecoderModelPatcher):
+    def __init__(
+        self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Dict[str, Any]
+    ):
+        model_type = model.config.model_type
+        patcher_for_model_type = {
+            "llama": LlamaModelPatcher,
+            "qwen2": UpdateCausalMaskModelPatcher,
+            "phi3": Phi3ModelPatcher,
+            "internlm2": InternLM2Patcher,
+        }
+        self._internal_patcher = None
+        self._patched_forward = None
+        internal_patcher_cls = patcher_for_model_type.get(model_type)
+        if internal_patcher_cls is not None:
+            self._internal_patcher = internal_patcher_cls(config, model, model_kwargs)
+            self._patched_forward = self._internal_patcher.patched_forward
+        super().__init__(config, model, model_kwargs)
+
+    @property
+    def patched_forward(self):
+        if self._internal_patcher is not None:
+            return self._internal_patcher.patched_forward
+        return self._patched_forward
+
+    @patched_forward.setter
+    def patched_forward(self, fn):
+        self._patched_forward = fn
+        if self._internal_patcher is not None:
+            self._internal_patcher.patched_forward = fn
+
+    def __enter__(self):
+        if is_torch_version(">=", "2.1.0"):
+            if self._model.config.model_type == "qwen2" and self._model.config._attn_implementation != "sdpa":
+                from transformers.models.qwen2.modeling_qwen2 import QWEN2_ATTENTION_CLASSES
+
+                sdpa_attn = QWEN2_ATTENTION_CLASSES["sdpa"]
+                self._model.config._orig_attn_implementation = self._model.config._attn_implementation
+                self._model.config._attn_implementation = "sdpa"
+
+                for layer in self._model.model.layers:
+                    layer.self_attn._orig_forward = layer.self_attn.forward
+                    layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
+
+            if self._model.config.model_type == "llama" and self._model.config._attn_implementation != "sdpa":
+                self._model.config._orig_attn_implementation = self._model.config._attn_implementation
+                self._model.config._attn_implementation = "sdpa"
+                if is_transformers_version("<", "4.47"):
+                    from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
+
+                    sdpa_attn = LLAMA_ATTENTION_CLASSES["sdpa"]
+                    for layer in self._model.model.layers:
+                        layer.self_attn._orig_forward = layer.self_attn.forward
+                        layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
+
+        if self._internal_patcher is not None:
+            return self._internal_patcher.__enter__()
+        return super().__enter__()
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        if self._internal_patcher:
+            self._internal_patcher.__exit__(exc_type, exc_value, traceback)
+        else:
+            super().__exit__(exc_type, exc_value, traceback)
+
+        if hasattr(self._model.config, "_orig_attn_implementation"):
+            self._model.config._attn_implementation = self._model.config._orig_attn_implementation
+            for layer in self._model.model.layers:
+                if hasattr(layer.self_attn, "_orig_forward"):
+                    layer.self_attn.forward = layer.self_attn._orig_forward


 def llava_vision_embed_forward(self, pixel_values):
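Model patchers are context managers: the exporter enters them around tracing, so the attention changes above apply only while the model is being converted and are undone on exit. A rough usage sketch (illustrative only; export_config, language_model, and dummy_inputs are assumed to exist, and the real export pipeline builds them internally):

import openvino as ov

patcher = InternVL2ChatLangModelPatcher(export_config, language_model, model_kwargs={})
with patcher:
    # Inside the context, qwen2/llama layers run SDPA attention and the
    # model-type-specific internal patcher (if any) is active.
    ov_model = ov.convert_model(language_model, example_input=dummy_inputs)
# On exit, the original attention implementation and forwards are restored.

Delegating to a per-architecture patcher through patcher_for_model_type keeps the InternVL2 wrapper thin: it only normalizes the attention implementation, while llama/qwen2/phi3/internlm2 quirks stay in their existing patchers.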
