@@ -421,9 +421,9 @@ def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, c
                 offset = 0
             mask_shape = attention_mask.shape
             mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-            causal_mask[
-                : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-            ] = mask_slice
+            causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
+                mask_slice
+            )
 
     if (
         self.config._attn_implementation == "sdpa"
@@ -2058,9 +2058,9 @@ def _dbrx_update_causal_mask_legacy(
                 offset = 0
             mask_shape = attention_mask.shape
             mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-            causal_mask[
-                : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-            ] = mask_slice
+            causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
+                mask_slice
+            )
 
     if (
         self.config._attn_implementation == "sdpa"
@@ -2710,40 +2710,6 @@ def patched_forward(*args, **kwargs):
 
         self.patched_forward = patched_forward
 
-    def __enter__(self):
-        super().__enter__()
-
-        if is_transformers_version(">=", "4.47.0"):
-            from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_FUNCTION
-
-            GEMMA2_ATTENTION_FUNCTION["original_eager"] = GEMMA2_ATTENTION_FUNCTION["eager"]
-            GEMMA2_ATTENTION_FUNCTION["eager"] = GEMMA2_ATTENTION_FUNCTION["sdpa"]
-
-        elif is_transformers_version(">=", "4.45.0"):
-            from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_CLASSES
-
-            sdpa_attn = GEMMA2_ATTENTION_CLASSES["sdpa"]
-            eager_attn = GEMMA2_ATTENTION_CLASSES["eager"]
-
-            for layer in self._model.model.layers:
-                if isinstance(layer.self_attn, eager_attn):
-                    layer.self_attn._orig_forward = layer.self_attn.forward
-                    layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        super().__exit__(exc_type, exc_value, traceback)
-
-        if is_transformers_version(">=", "4.47.0"):
-            from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_FUNCTION
-
-            GEMMA2_ATTENTION_FUNCTION["eager"] = GEMMA2_ATTENTION_FUNCTION["original_eager"]
-            del GEMMA2_ATTENTION_FUNCTION["original_eager"]
-
-        elif is_transformers_version(">=", "4.45.0"):
-            for layer in self._model.model.layers:
-                if hasattr(layer.self_attn, "_orig_forward"):
-                    layer.self_attn.forward = layer.self_attn._orig_forward
-
 
 def _decilm_attn_forward(
     self,