@@ -2681,32 +2681,6 @@ def __exit__(self, exc_type, exc_value, traceback):
         unpatch_update_causal_mask(self._model, "gpt_neox_japanese")
 
 
-def _gpt_neo_sdpa_attn(self, query, key, value, attention_mask=None, head_mask=None):
-    # Keep the attention weights computation in fp32 to avoid overflow issues
-    query = query.to(torch.float32)
-    key = key.to(torch.float32)
-
-    # Apply sliding window masking for local attention layers
-    query_length, key_length = query.size(-2), key.size(-2)
-    causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
-    # different from original for prevent overflow, apply to mask instead of directly to weights
-    mask_value = torch.finfo(torch.float16).min
-    # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
-    # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-    mask_value = torch.tensor(mask_value, dtype=query.dtype).to(query.device)
-    if attention_mask is None:
-        attention_mask = torch.ones_like(causal_mask)
-    attention_mask = torch.where(causal_mask, attention_mask[:, :, :, : key.shape[-2]], mask_value)
-
-    # Mask heads if we want to
-    if head_mask is not None:
-        attention_mask = attention_mask * head_mask
-
-    attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
-
-    return attn_output, None
-
-
 def _gpt_neo_attn_forward(
     self,
     hidden_states,
@@ -2731,22 +2705,70 @@ def _gpt_neo_attn_forward(
     )
 
 
+# Adopted from https://github.com/huggingface/optimum/blob/main/optimum/bettertransformer/models/attention.py#L185
+def _gpt_neo_attn_sdpa(
+    self,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    head_mask: Optional[torch.Tensor] = None,
+):
+    batch_size = query.shape[0]
+
+    mask_value = torch.finfo(torch.float16).min
+    mask_value = torch.full([], mask_value, dtype=value.dtype)
+
+    dropout_p = float(self.config.attention_dropout) if self.training else 0.0
+    if (batch_size == 1 or self.training) and self.attention_type == "global":
+        if query.shape[2] > 1:
+            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
+            )
+        else:
+            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False, scale=1.0
+            )
+    else:
+        query_length, key_length = query.size(-2), key.size(-2)
+
+        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+
+        causal_mask = torch.where(causal_mask, 0, mask_value)
+        if batch_size > 1:
+            # torch.Tensor.expand does no memory copy
+            causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
+
+        if attention_mask is not None:
+            attention_mask = causal_mask + attention_mask
+
+        sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False, scale=1.0
+        )
+
+    return sdpa_result, None
+
+
 class GptNeoModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
         if is_transformers_version(">=", "4.45.0") and is_torch_version(">=", "2.1.0"):
+            self._model.config._orig_attn_implementation = self._model.config._attn_implementation
+            self._model.config._attn_implementation = "sdpa"
             for layer in self._model.transformer.h:
                 self_attn = layer.attn.attention
                 self_attn._orig_attn = self_attn._attn
-                self_attn._attn = types.MethodType(_gpt_neo_sdpa_attn, self_attn)
+                self_attn._attn = types.MethodType(_gpt_neo_attn_sdpa, self_attn)
                 self_attn._orig_forward = types.MethodType(_gpt_neo_attn_forward, self_attn)
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
-        for layer in self._model.transformer.h:
-            if hasattr(layer.attn.attention, "_orig_forward"):
-                layer.attn.attention.forward = layer.attn.attention._orig_forward
-                layer.attn.attention._attn = layer.attn.attention._orig_attn
+        if hasattr(self._model.config, "_orig_attn_implementation"):
+            self._model.config._attn_implementation = self._model.config._orig_attn_implementation
+            for layer in self._model.transformer.h:
+                layer.attn.attention.forward = layer.attn.attention._orig_forward
+                layer.attn.attention._attn = layer.attn.attention._orig_attn
 
 
 class Gemma2ModelPatcher(LlamaModelPatcher):
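The patcher classes in this diff all rely on the same patch-and-restore idiom: bind a replacement function onto each attention module with types.MethodType when the context is entered, store the original on an _orig_* attribute, and put it back on exit. Below is a minimal, self-contained sketch of that pattern; Attention, _sdpa_attn, and AttnPatcher are illustrative toy names, not identifiers from this repository.

import types


class Attention:
    def _attn(self, x):
        return x + 1  # stand-in for the eager attention path


def _sdpa_attn(self, x):
    return x + 2  # stand-in for the SDPA-based replacement


class AttnPatcher:
    def __init__(self, module):
        self._module = module

    def __enter__(self):
        # keep a handle to the original bound method so it can be restored later
        self._module._orig_attn = self._module._attn
        self._module._attn = types.MethodType(_sdpa_attn, self._module)
        return self._module

    def __exit__(self, exc_type, exc_value, traceback):
        # restore the original attention implementation
        self._module._attn = self._module._orig_attn


attn = Attention()
with AttnPatcher(attn):
    assert attn._attn(0) == 2  # patched: SDPA stand-in is used
assert attn._attn(0) == 1  # original eager path is back

Stashing the originals on _orig_attn / _orig_forward, as GptNeoModelPatcher does, keeps __exit__ trivial and leaves the model usable outside the export context.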