@@ -421,9 +421,9 @@ def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, c
                 offset = 0
             mask_shape = attention_mask.shape
             mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-            causal_mask[
-                : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-            ] = mask_slice
+            causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
+                mask_slice
+            )
 
     if (
         self.config._attn_implementation == "sdpa"
@@ -2058,9 +2058,9 @@ def _dbrx_update_causal_mask_legacy(
                 offset = 0
             mask_shape = attention_mask.shape
             mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-            causal_mask[
-                : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-            ] = mask_slice
+            causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
+                mask_slice
+            )
 
     if (
         self.config._attn_implementation == "sdpa"
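
Both hunks above are the same black-style rewrap of an unchanged assignment. For readers skimming the diff, a minimal standalone sketch of the mask math those lines implement (shapes and values below are illustrative, not taken from the patch):

    import torch

    dtype = torch.float32
    min_dtype = torch.finfo(dtype).min  # most negative finite value for this dtype

    # 4-D mask of shape (batch, heads, query_len, kv_len): 1.0 = attend, 0.0 = masked
    attention_mask = torch.tensor([[[[1.0, 1.0, 0.0]]]])

    # Padding positions (== 0.0) become a large negative additive bias; the rest stay ~0
    mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype

    # The rewrapped line then writes this slice into the preallocated causal_mask
    # buffer at the given offset; only the line wrapping changed, not the semantics
    mask_shape = attention_mask.shape
    offset = 0
    causal_mask = torch.zeros(mask_shape, dtype=dtype)
    causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
        mask_slice
    )
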
@@ -2712,7 +2712,14 @@ def patched_forward(*args, **kwargs):
 
     def __enter__(self):
         super().__enter__()
-        if is_transformers_version(">=", "4.45.0"):
+
+        if is_transformers_version(">=", "4.47.0"):
+            from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_FUNCTION
+
+            GEMMA2_ATTENTION_FUNCTION["original_eager"] = GEMMA2_ATTENTION_FUNCTION["eager"]
+            GEMMA2_ATTENTION_FUNCTION["eager"] = GEMMA2_ATTENTION_FUNCTION["sdpa"]
+
+        elif is_transformers_version(">=", "4.45.0"):
             from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_CLASSES
 
             sdpa_attn = GEMMA2_ATTENTION_CLASSES["sdpa"]
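
The new >= 4.47.0 branch targets the attention dispatch dict that transformers introduced in place of per-class attention implementations. As a minimal sketch of the stash-and-swap idiom it uses, here with a plain stand-in dict and dummy callables (only GEMMA2_ATTENTION_FUNCTION itself comes from transformers; every name below is hypothetical):

    def eager_attn(*args, **kwargs):
        return "eager result"

    def sdpa_attn(*args, **kwargs):
        return "sdpa result"

    # Stand-in for transformers' GEMMA2_ATTENTION_FUNCTION registry
    ATTENTION_FUNCTION = {"eager": eager_attn, "sdpa": sdpa_attn}

    # __enter__ side: stash the original under a spare key, then point "eager" at
    # sdpa, so any module looking up "eager" transparently runs the SDPA path
    ATTENTION_FUNCTION["original_eager"] = ATTENTION_FUNCTION["eager"]
    ATTENTION_FUNCTION["eager"] = ATTENTION_FUNCTION["sdpa"]
    assert ATTENTION_FUNCTION["eager"]() == "sdpa result"
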
@@ -2725,7 +2732,14 @@ def __enter__(self):
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
-        if is_transformers_version(">=", "4.45.0"):
+
+        if is_transformers_version(">=", "4.47.0"):
+            from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_FUNCTION
+
+            GEMMA2_ATTENTION_FUNCTION["eager"] = GEMMA2_ATTENTION_FUNCTION["original_eager"]
+            del GEMMA2_ATTENTION_FUNCTION["original_eager"]
+
+        elif is_transformers_version(">=", "4.45.0"):
             for layer in self._model.model.layers:
                 if hasattr(layer.self_attn, "_orig_forward"):
                     layer.self_attn.forward = layer.self_attn._orig_forward
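
The __exit__ branch mirrors the swap: restoring from the stash key and then deleting it leaves the registry exactly as it was, which matters because the dict is module-level state shared by every Gemma2 model in the process. Continuing the sketch above (same hypothetical stand-ins):

    # __exit__ side: put the saved callable back and drop the temporary key
    ATTENTION_FUNCTION["eager"] = ATTENTION_FUNCTION["original_eager"]
    del ATTENTION_FUNCTION["original_eager"]
    assert ATTENTION_FUNCTION["eager"]() == "eager result"

On the pre-4.47 path the patched state lives on the layers instead, so __exit__ walks self._model.model.layers and restores each self_attn.forward from its saved _orig_forward attribute.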