@@ -718,14 +718,15 @@ def _mistral_update_causal_mask(
 class MistralModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
-        if is_transformers_version(">=", "4.42.0"):
+        if is_transformers_version(">=", "4.42.0") and is_transformers_version("<", "4.48.0"):
             # apply fix https://github.com/huggingface/transformers/commit/57d7594a79a9f5d835abf2d4d384db0e4818e548
             self._model.model._orig_update_causal_mask = self._model.model._update_causal_mask
             self._model.model._update_causal_mask = types.MethodType(_mistral_update_causal_mask, self._model.model)

         else:
             for layer in self._model.model.layers:
-                _reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)
+                if hasattr(layer.self_attn, "rotary_emb"):
+                    _reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)

     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
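For context on the pattern being version-gated here: the patcher keeps the model's original bound method under an `_orig_*` attribute, rebinds a replacement with `types.MethodType` on `__enter__`, and restores it on `__exit__`. Below is a minimal, self-contained sketch of that save/rebind/restore cycle; `ToyModel` and `_patched_update_causal_mask` are hypothetical stand-ins, not the optimum-intel code.

import types


class ToyModel:  # hypothetical stand-in for self._model.model
    def _update_causal_mask(self, mask):
        return mask


def _patched_update_causal_mask(self, mask):
    # stand-in for _mistral_update_causal_mask; returns the mask unchanged in this sketch
    return mask


class ToyPatcher:
    def __init__(self, model):
        self._model = model

    def __enter__(self):
        # keep a handle to the original bound method so it can be restored later
        self._model._orig_update_causal_mask = self._model._update_causal_mask
        # rebind the replacement as a method of this specific instance
        self._model._update_causal_mask = types.MethodType(_patched_update_causal_mask, self._model)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._model._update_causal_mask = self._model._orig_update_causal_mask


model = ToyModel()
with ToyPatcher(model):
    model._update_causal_mask(None)  # the patched version is active only inside the block

Because the replacement is bound to the instance rather than the class, other instances of the same class are left untouched.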
@@ -734,7 +735,7 @@ def __exit__(self, exc_type, exc_value, traceback):
             self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask

         for layer in self._model.model.layers:
-            if hasattr(layer.self_attn.rotary_emb, "_orig_forward"):
+            if hasattr(layer.self_attn, "rotary_emb") and hasattr(layer.self_attn.rotary_emb, "_orig_forward"):
                 layer.self_attn.rotary_emb.forward = layer.self_attn.rotary_emb._orig_forward

@@ -2493,7 +2494,9 @@ class UpdateCausalMaskModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
         patch_update_causal_mask(self._model, "4.42.0")
-        if hasattr(self._model.model.layers[0].self_attn.rotary_emb, "_set_cos_sin_cache"):
+        if hasattr(self._model.model.layers[0].self_attn, "rotary_emb") and hasattr(
+            self._model.model.layers[0].self_attn.rotary_emb, "_set_cos_sin_cache"
+        ):
             for layer in self._model.model.layers:
                 _reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)

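The added `hasattr` checks in this hunk (and in `MistralModelPatcher` above) defend against layer layouts where `self_attn` no longer exposes a `rotary_emb` attribute. A hedged sketch of the same guard factored into a helper; `reinit_fn` stands in for `_reinitialize_cos_sin_cached_fp32`, and the `model.model.layers` layout is assumed for illustration, not taken from the library.

def reinitialize_rotary_fp32(model, reinit_fn):
    """Apply `reinit_fn` only to per-layer rotary embeddings that actually exist."""
    for layer in model.model.layers:
        rotary = getattr(layer.self_attn, "rotary_emb", None)
        # skip layers without a per-attention rotary embedding, and those
        # without the cos/sin cache hook the reinit function relies on
        if rotary is not None and hasattr(rotary, "_set_cos_sin_cache"):
            reinit_fn(rotary)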
@@ -3045,15 +3048,16 @@ def patched_forward(self, fn):
     def __enter__(self):
        if is_torch_version(">=", "2.1.0"):
            if self._model.config.model_type == "qwen2" and self._model.config._attn_implementation != "sdpa":
-                from transformers.models.qwen2.modeling_qwen2 import QWEN2_ATTENTION_CLASSES
+                if is_transformers_version("<", "4.48"):
+                    from transformers.models.qwen2.modeling_qwen2 import QWEN2_ATTENTION_CLASSES

-                sdpa_attn = QWEN2_ATTENTION_CLASSES["sdpa"]
-                self._model.config._orig_attn_implementation = self._model.config._attn_implementation
-                self._model.config._attn_implementation = "sdpa"
+                    sdpa_attn = QWEN2_ATTENTION_CLASSES["sdpa"]
+                    self._model.config._orig_attn_implementation = self._model.config._attn_implementation
+                    self._model.config._attn_implementation = "sdpa"

-                for layer in self._model.model.layers:
-                    layer.self_attn._orig_forward = layer.self_attn.forward
-                    layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
+                    for layer in self._model.model.layers:
+                        layer.self_attn._orig_forward = layer.self_attn.forward
+                        layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)

            if self._model.config.model_type == "llama" and self._model.config._attn_implementation != "sdpa":
                self._model.config._orig_attn_implementation = self._model.config._attn_implementation
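The new `is_transformers_version("<", "4.48")` gate keeps the `QWEN2_ATTENTION_CLASSES` import and the per-layer forward swap on the legacy path only. A rough sketch of the kind of comparison such a guard performs, written against `packaging` directly; `transformers_older_than` is a hypothetical helper for illustration, not the project's `is_transformers_version`.

import transformers
from packaging import version


def transformers_older_than(target: str) -> bool:
    # compare the installed transformers release against a target version string
    return version.parse(transformers.__version__) < version.parse(target)


if transformers_older_than("4.48"):
    # legacy path: QWEN2_ATTENTION_CLASSES can be imported and the per-layer
    # forward swap applies, as in the hunk above
    pass
else:
    # on newer releases that mapping is assumed to be unavailable, so the patch is skipped
    pass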