
Commit b8b419d

disable flash_attn during export internvl2

1 parent fe10aaa

2 files changed: +87 -1 lines changed

optimum/exporters/openvino/model_configs.py (+6 -1)
@@ -87,6 +87,7 @@
     InputEmbeddingPatcher,
     InternLM2Patcher,
     InternLMModelPatcher,
+    InternVL2ChatLangModelPatcher,
     InternVLChatImageEmbeddingModelPatcher,
     JaisModelPatcher,
     LlamaModelPatcher,
@@ -1642,7 +1643,11 @@ def with_behavior(
         if behavior == InternVLChatConfigBehavior.LANGUAGE:
             model_type = self._orig_config.llm_config.model_type
             return get_vlm_text_generation_config(
-                model_type, self._orig_config.llm_config, self.int_dtype, self.float_dtype
+                model_type,
+                self._orig_config.llm_config,
+                self.int_dtype,
+                self.float_dtype,
+                InternVL2ChatLangModelPatcher,
             )

         if behavior == InternVLChatConfigBehavior.VISION_EMBEDDINGS:
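
The new trailing argument selects which ModelPatcher class wraps the language model while it is traced. As a rough illustration of that wiring, here is a minimal sketch with hypothetical names (the real get_vlm_text_generation_config and its per-architecture registry in optimum-intel are more involved than this):

# Sketch only: hypothetical names, not the actual optimum-intel implementation.
from typing import Any, Dict, Optional, Type


class DummyTextGenerationExportConfig:
    """Stand-in for a per-architecture text-generation export config."""

    def __init__(self, llm_config, int_dtype: str, float_dtype: str, patcher_cls: Optional[Type] = None):
        self._config = llm_config
        self.int_dtype = int_dtype
        self.float_dtype = float_dtype
        self._patcher_cls = patcher_cls

    def patch_model_for_export(self, model, model_kwargs: Optional[Dict[str, Any]] = None):
        # When a custom patcher class is supplied (as InternVL2ChatLangModelPatcher is
        # in the diff above), it is instantiated instead of the default decoder patcher.
        if self._patcher_cls is None:
            raise NotImplementedError("default patcher omitted in this sketch")
        return self._patcher_cls(self, model, model_kwargs=model_kwargs or {})


def get_vlm_text_generation_config_sketch(model_type, llm_config, int_dtype, float_dtype, patcher_cls=None):
    # The real factory dispatches on model_type; the sketch always builds the dummy config.
    return DummyTextGenerationExportConfig(llm_config, int_dtype, float_dtype, patcher_cls)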

optimum/exporters/openvino/model_patcher.py (+81)
@@ -21,6 +21,7 @@

 import torch
 import torch.nn.functional as F
+from transformers import PreTrainedModel, TFPreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
 from transformers.utils import is_tf_available

@@ -2992,11 +2993,91 @@ def __init__(
         model.__orig_forward = model.forward
         model.forward = model.extract_feature

+        if model.vision_model.encoder.layers[0].attn.use_flash_attn:
+            for layer in model.vision_model.encoder.layers:
+                layer.attn._orig_use_flash_attn = layer.attn.use_flash_attn
+                layer.attn.use_flash_attn = False
+
         super().__init__(config, model, model_kwargs)

     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         self._model.forward = self._model.__orig_forward
+        if hasattr(self._model.vision_model.encoder.layers[0].attn, "_orig_use_flash_attn"):
+            for layer in self._model.vision_model.encoder.layers:
+                layer.attn.use_flash_attn = layer.attn._orig_use_flash_attn
+
+
+class InternVL2ChatLangModelPatcher(DecoderModelPatcher):
+    def __init__(
+        self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Dict[str, Any]
+    ):
+        model_type = model.config.model_type
+        patcher_for_model_type = {
+            "llama": LlamaModelPatcher,
+            "qwen2": UpdateCausalMaskModelPatcher,
+            "phi3": Phi3ModelPatcher,
+            "internlm2": InternLM2Patcher,
+        }
+        self._internal_patcher = None
+        self._patched_forward = None
+        internal_patcher_cls = patcher_for_model_type.get(model_type)
+        if internal_patcher_cls is not None:
+            self._internal_patcher = internal_patcher_cls(config, model, model_kwargs)
+            self._patched_forward = self._internal_patcher.patched_forward
+        super().__init__(config, model, model_kwargs)
+
+    @property
+    def patched_forward(self):
+        if self._internal_patcher is not None:
+            return self._internal_patcher.patched_forward
+        return self._patched_forward
+
+    @patched_forward.setter
+    def patched_forward(self, fn):
+        self._patched_forward = fn
+        if self._internal_patcher is not None:
+            self._internal_patcher.patched_forward = fn
+
+    def __enter__(self):
+        if is_torch_version(">=", "2.1.0"):
+            if self._model.config.model_type == "qwen2" and self._model.config._attn_implementation != "sdpa":
+                from transformers.models.qwen2.modeling_qwen2 import QWEN2_ATTENTION_CLASSES
+
+                sdpa_attn = QWEN2_ATTENTION_CLASSES["sdpa"]
+                self._model.config._orig_attn_implementation = self._model.config._attn_implementation
+                self._model.config._attn_implementation = "sdpa"
+
+                for layer in self._model.model.layers:
+                    layer.self_attn._orig_forward = layer.self_attn.forward
+                    layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
+
+            if self._model.config.model_type == "llama" and self._model.config._attn_implementation != "sdpa":
+                self._model.config._orig_attn_implementation = self._model.config._attn_implementation
+                self._model.config._attn_implementation = "sdpa"
+                if is_transformers_version("<", "4.47"):
+                    from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
+
+                    sdpa_attn = LLAMA_ATTENTION_CLASSES["sdpa"]
+                    for layer in self._model.model.layers:
+                        layer.self_attn._orig_forward = layer.self_attn.forward
+                        layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
+
+        if self._internal_patcher is not None:
+            return self._internal_patcher.__enter__()
+        return super().__enter__()
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        if self._internal_patcher:
+            self._internal_patcher.__exit__(exc_type, exc_value, traceback)
+        else:
+            super().__exit__(exc_type, exc_value, traceback)
+
+        if hasattr(self._model.config, "_orig_attn_implementation"):
+            self._model.config._attn_implementation = self._model.config._orig_attn_implementation
+            for layer in self._model.model.layers:
+                if hasattr(layer.self_attn, "_orig_forward"):
+                    layer.self_attn.forward = layer.self_attn._orig_forward


 def llava_vision_embed_forward(self, pixel_values):
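
For reference, a patcher like this is entered as a context manager around tracing: inside the context the decoder temporarily runs with SDPA attention and the architecture-specific patched forward, and everything is restored on exit. A minimal usage sketch, assuming the LANGUAGE-behavior export config from model_configs.py above and a prepared tuple of dummy decoder inputs (the torch.jit.trace call is illustrative; the export pipeline applies the patcher itself via the config):

# Usage sketch only; in practice the export pipeline calls the config's
# patch_model_for_export hook rather than instantiating the patcher by hand.
from typing import Any, Tuple

import torch

from optimum.exporters.openvino.model_patcher import InternVL2ChatLangModelPatcher


def trace_internvl2_language_model(export_config, lang_model, example_inputs: Tuple[Any, ...]):
    # export_config: the LANGUAGE-behavior config built by with_behavior() above.
    # lang_model: the InternVL2 decoder (e.g. the checkpoint's language_model submodule).
    # example_inputs: dummy inputs matching the decoder's forward signature.
    patcher = InternVL2ChatLangModelPatcher(export_config, lang_model, model_kwargs={})
    with patcher:
        # Inside the context the decoder uses SDPA attention and the patched forward,
        # which keeps tracing export-friendly.
        traced = torch.jit.trace(lang_model, example_inputs, strict=False)
    # __exit__ restores the original attention implementation and per-layer forwards.
    return traced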
