 
 import torch
 import torch.nn.functional as F
+from transformers import PreTrainedModel, TFPreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
 from transformers.utils import is_tf_available
 
+from optimum.exporters.onnx.base import OnnxConfig
 from optimum.exporters.onnx.model_patcher import DecoderModelPatcher, ModelPatcher, override_arguments
 from optimum.intel.utils.import_utils import (
     _openvino_version,
@@ -2987,11 +2989,91 @@ def __init__(
         model.__orig_forward = model.forward
         model.forward = model.extract_feature
 
+        if model.vision_model.encoder.layers[0].attn.use_flash_attn:
+            for layer in model.vision_model.encoder.layers:
+                layer.attn._orig_use_flash_attn = layer.attn.use_flash_attn
+                layer.attn.use_flash_attn = False
+
         super().__init__(config, model, model_kwargs)
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         self._model.forward = self._model.__orig_forward
+        if hasattr(self._model.vision_model.encoder.layers[0].attn, "_orig_use_flash_attn"):
+            for layer in self._model.vision_model.encoder.layers:
+                layer.attn.use_flash_attn = layer.attn._orig_use_flash_attn
+
+
+class InternVL2ChatLangModelPatcher(DecoderModelPatcher):
+    def __init__(
+        self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Dict[str, Any]
+    ):
+        model_type = model.config.model_type
+        patcher_for_model_type = {
+            "llama": LlamaModelPatcher,
+            "qwen2": UpdateCausalMaskModelPatcher,
+            "phi3": Phi3ModelPatcher,
+            "internlm2": InternLM2Patcher,
+        }
+        self._internal_patcher = None
+        self._patched_forward = None
+        internal_patcher_cls = patcher_for_model_type.get(model_type)
+        if internal_patcher_cls is not None:
+            self._internal_patcher = internal_patcher_cls(config, model, model_kwargs)
+            self._patched_forward = self._internal_patcher.patched_forward
+        super().__init__(config, model, model_kwargs)
+
+    @property
+    def patched_forward(self):
+        if self._internal_patcher is not None:
+            return self._internal_patcher.patched_forward
+        return self._patched_forward
+
+    @patched_forward.setter
+    def patched_forward(self, fn):
+        self._patched_forward = fn
+        if self._internal_patcher is not None:
+            self._internal_patcher.patched_forward = fn
+
+    def __enter__(self):
+        if is_torch_version(">=", "2.1.0"):
+            if self._model.config.model_type == "qwen2" and self._model.config._attn_implementation != "sdpa":
+                from transformers.models.qwen2.modeling_qwen2 import QWEN2_ATTENTION_CLASSES
+
+                sdpa_attn = QWEN2_ATTENTION_CLASSES["sdpa"]
+                self._model.config._orig_attn_implementation = self._model.config._attn_implementation
+                self._model.config._attn_implementation = "sdpa"
+
+                for layer in self._model.model.layers:
+                    layer.self_attn._orig_forward = layer.self_attn.forward
+                    layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
+
+            if self._model.config.model_type == "llama" and self._model.config._attn_implementation != "sdpa":
+                self._model.config._orig_attn_implementation = self._model.config._attn_implementation
+                self._model.config._attn_implementation = "sdpa"
+                if is_transformers_version("<", "4.47"):
+                    from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
+
+                    sdpa_attn = LLAMA_ATTENTION_CLASSES["sdpa"]
+                    for layer in self._model.model.layers:
+                        layer.self_attn._orig_forward = layer.self_attn.forward
+                        layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
+
+        if self._internal_patcher is not None:
+            return self._internal_patcher.__enter__()
+        return super().__enter__()
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        if self._internal_patcher:
+            self._internal_patcher.__exit__(exc_type, exc_value, traceback)
+        else:
+            super().__exit__(exc_type, exc_value, traceback)
+
+        if hasattr(self._model.config, "_orig_attn_implementation"):
+            self._model.config._attn_implementation = self._model.config._orig_attn_implementation
+        for layer in self._model.model.layers:
+            if hasattr(layer.self_attn, "_orig_forward"):
+                layer.self_attn.forward = layer.self_attn._orig_forward
 
 
 def llava_vision_embed_forward(self, pixel_values):
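Note: both patchers in this hunk rely on the same stash-and-restore idiom: the original attribute (a forward method, a flash-attention flag, an attention implementation name) is saved under an _orig_* name when the patcher is created or entered, and put back in __exit__. Below is a minimal, self-contained sketch of that idiom; TinyAttention, sdpa_forward, and AttentionPatcher are invented names used purely for illustration and are not part of optimum-intel.

# Minimal, self-contained sketch of the stash-and-restore pattern used above.
# TinyAttention, sdpa_forward, and AttentionPatcher are made-up names for
# illustration only; they are not part of optimum-intel or transformers.
import types


class TinyAttention:
    def forward(self, x):
        return f"eager({x})"


def sdpa_forward(self, x):
    # Stand-in for an export-friendly (e.g. SDPA) attention forward.
    return f"sdpa({x})"


class AttentionPatcher:
    def __init__(self, module):
        self._module = module

    def __enter__(self):
        # Stash the original bound method under an _orig_* attribute ...
        self._module._orig_forward = self._module.forward
        # ... and swap in the replacement for the duration of the context.
        self._module.forward = types.MethodType(sdpa_forward, self._module)
        return self._module

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the original forward so the wrapped module is left untouched.
        if hasattr(self._module, "_orig_forward"):
            self._module.forward = self._module._orig_forward


attn = TinyAttention()
with AttentionPatcher(attn):
    assert attn.forward("x") == "sdpa(x)"  # patched inside the context
assert attn.forward("x") == "eager(x)"  # original behaviour restored on exit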