import torch
import torch.nn.functional as F
+from transformers import PreTrainedModel, TFPreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
from transformers.utils import is_tf_available

+from optimum.exporters.onnx.base import OnnxConfig
from optimum.exporters.onnx.model_patcher import (
    DecoderModelPatcher,
    ModelPatcher,
@@ -114,18 +116,20 @@ def patch_model_with_bettertransformer(model):
    return model


-def patch_update_causal_mask(model, transformers_version, inner_model_name="model", patch_fn=None):
+def patch_update_causal_mask(
+    model, transformers_version, inner_model_name="model", patch_fn=None, patch_external_model=False
+):
    if is_transformers_version(">=", transformers_version):
-        inner_model = getattr(model, inner_model_name, None)
+        inner_model = getattr(model, inner_model_name, None) if not patch_external_model else model
        if inner_model is not None:
            if hasattr(inner_model, "_update_causal_mask"):
                inner_model._orig_update_causal_mask = inner_model._update_causal_mask
                patch_fn = patch_fn or _llama_gemma_update_causal_mask
                inner_model._update_causal_mask = types.MethodType(patch_fn, inner_model)


-def unpatch_update_causal_mask(model, inner_model_name="model"):
-    inner_model = getattr(model, inner_model_name, None)
+def unpatch_update_causal_mask(model, inner_model_name="model", patch_external_model=False):
+    inner_model = getattr(model, inner_model_name, None) if not patch_external_model else model
    if inner_model is not None and hasattr(inner_model, "_orig_update_causal_mask"):
        inner_model._update_causal_mask = inner_model._orig_update_causal_mask
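For reference, the new `patch_external_model` flag only changes which object receives the patched `_update_causal_mask`: by default the helpers look the decoder up as a submodule (via `inner_model_name`, e.g. `model.model`), while with `patch_external_model=True` the object passed in is patched directly. A minimal usage sketch, assuming a `*ForCausalLM` wrapper (`causal_lm`) and a bare decoder used as a text encoder (`text_encoder`); both names are illustrative and not part of this diff:

```python
# Default: the decoder is reached through getattr(causal_lm, "model"),
# so the inner module gets the patched _update_causal_mask.
patch_update_causal_mask(causal_lm, "4.39.0")

# patch_external_model=True: the passed object *is* the decoder,
# so it is patched directly and inner_model_name is ignored.
patch_update_causal_mask(text_encoder, "4.39.0", None, patch_external_model=True)

# ... run export / tracing with the patched mask ...

# Restore the original implementations afterwards.
unpatch_update_causal_mask(causal_lm)
unpatch_update_causal_mask(text_encoder, None, True)
```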
@@ -3791,3 +3795,29 @@ def patched_forward(*args, **kwargs):
        model.forward = patched_forward

        super().__init__(config, model, model_kwargs)
+
+
+class SanaTextEncoderModelPatcher(ModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        patch_update_causal_mask(self._model, "4.39.0", None, patch_external_model=True)
+
+        if self._model.config._attn_implementation != "sdpa":
+            self._model.config._orig_attn_implementation = self._model.config._attn_implementation
+            self._model.config._attn_implementation = "sdpa"
+            if is_transformers_version("<", "4.47.0"):
+                from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_CLASSES
+
+                sdpa_attn = GEMMA2_ATTENTION_CLASSES["sdpa"]
+                for layer in self._model.layers:
+                    layer.self_attn._orig_forward = layer.self_attn.forward
+                    layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        unpatch_update_causal_mask(self._model, None, True)
+        if hasattr(self._model.config, "_orig_attn_implementation"):
+            self._model.config._attn_implementation = self._model.config._orig_attn_implementation
+        for layer in self._model.layers:
+            if hasattr(layer.self_attn, "_orig_forward"):
+                layer.self_attn.forward = layer.self_attn._orig_forward
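Like other `ModelPatcher` subclasses, the new patcher applies its changes in `__enter__` and reverts them in `__exit__`, so export code only needs to wrap tracing in a `with` block. A rough sketch of how it would be driven; `onnx_config`, `text_encoder` and `example_inputs` are placeholder objects assumed for illustration, not part of this diff:

```python
patcher = SanaTextEncoderModelPatcher(onnx_config, text_encoder)

with patcher:
    # Inside the block the Gemma2 text encoder uses the patched causal-mask
    # helper and, on transformers < 4.47.0, the SDPA attention forward.
    traced = torch.jit.trace(text_encoder, example_inputs)

# On exit the original _update_causal_mask, _attn_implementation and
# per-layer attention forwards are restored.
```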