@@ -2654,32 +2654,6 @@ def __exit__(self, exc_type, exc_value, traceback):
        unpatch_update_causal_mask(self._model, "gpt_neox_japanese")


-def _gpt_neo_sdpa_attn(self, query, key, value, attention_mask=None, head_mask=None):
-    # Keep the attention weights computation in fp32 to avoid overflow issues
-    query = query.to(torch.float32)
-    key = key.to(torch.float32)
-
-    # Apply sliding window masking for local attention layers
-    query_length, key_length = query.size(-2), key.size(-2)
-    causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
-    # different from original for prevent overflow, apply to mask instead of directly to weights
-    mask_value = torch.finfo(torch.float16).min
-    # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
-    # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-    mask_value = torch.tensor(mask_value, dtype=query.dtype).to(query.device)
-    if attention_mask is None:
-        attention_mask = torch.ones_like(causal_mask)
-    attention_mask = torch.where(causal_mask, attention_mask[:, :, :, : key.shape[-2]], mask_value)
-
-    # Mask heads if we want to
-    if head_mask is not None:
-        attention_mask = attention_mask * head_mask
-
-    attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
-
-    return attn_output, None
-
-
def _gpt_neo_attn_forward(
    self,
    hidden_states,
@@ -2704,22 +2678,70 @@ def _gpt_neo_attn_forward(
    )


+# Adapted from https://github.com/huggingface/optimum/blob/main/optimum/bettertransformer/models/attention.py#L185
+def _gpt_neo_attn_sdpa(
+    self,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    head_mask: Optional[torch.Tensor] = None,
+):
+    batch_size = query.shape[0]
+
+    mask_value = torch.finfo(torch.float16).min
+    mask_value = torch.full([], mask_value, dtype=value.dtype)
+
+    dropout_p = float(self.config.attention_dropout) if self.training else 0.0
+    if (batch_size == 1 or self.training) and self.attention_type == "global":
+        if query.shape[2] > 1:
+            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
+            )
+        else:
+            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False, scale=1.0
+            )
+    else:
+        query_length, key_length = query.size(-2), key.size(-2)
+
+        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+
+        causal_mask = torch.where(causal_mask, 0, mask_value)
+        if batch_size > 1:
+            # torch.Tensor.expand does no memory copy
+            causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
+
+        if attention_mask is not None:
+            attention_mask = causal_mask + attention_mask
+
+        sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False, scale=1.0
+        )
+
+    return sdpa_result, None
+
+
class GptNeoModelPatcher(DecoderModelPatcher):
    def __enter__(self):
        super().__enter__()
        if is_transformers_version(">=", "4.45.0") and is_torch_version(">=", "2.1.0"):
+            self._model.config._orig_attn_implementation = self._model.config._attn_implementation
+            self._model.config._attn_implementation = "sdpa"
            for layer in self._model.transformer.h:
                self_attn = layer.attn.attention
                self_attn._orig_attn = self_attn._attn
-                self_attn._attn = types.MethodType(_gpt_neo_sdpa_attn, self_attn)
+                self_attn._attn = types.MethodType(_gpt_neo_attn_sdpa, self_attn)
                self_attn._orig_forward = types.MethodType(_gpt_neo_attn_forward, self_attn)

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
-        for layer in self._model.transformer.h:
-            if hasattr(layer.attn.attention, "_orig_forward"):
-                layer.attn.attention.forward = layer.attn.attention._orig_forward
-                layer.attn.attention._attn = layer.attn.attention._orig_attn
+        if hasattr(self._model.config, "_orig_attn_implementation"):
+            self._model.config._attn_implementation = self._model.config._orig_attn_implementation
+            for layer in self._model.transformer.h:
+                layer.attn.attention.forward = layer.attn.attention._orig_forward
+                layer.attn.attention._attn = layer.attn.attention._orig_attn


class Gemma2ModelPatcher(LlamaModelPatcher):