@@ -2681,6 +2681,96 @@ def __exit__(self, exc_type, exc_value, traceback):
        unpatch_update_causal_mask(self._model, "gpt_neox_japanese")


+def _gpt_neo_attn_forward(
+    self,
+    hidden_states,
+    attention_mask=None,
+    layer_past=None,
+    head_mask=None,
+    use_cache=False,
+    output_attentions=False,
+    cache_position=None,
+):
+    # SDPA cannot return attention weights, so switch back to the original
+    # eager attention implementation when they are explicitly requested.
+    if output_attentions:
+        self._attn = self._orig_attn
+
+    return self._orig_forward(
+        hidden_states,
+        attention_mask=attention_mask,
+        layer_past=layer_past,
+        head_mask=head_mask,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        cache_position=cache_position,
+    )
+
+
+# Adapted from https://github.com/huggingface/optimum/blob/main/optimum/bettertransformer/models/attention.py#L185
+def _gpt_neo_attn_sdpa(
+    self,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    head_mask: Optional[torch.Tensor] = None,
+):
+    batch_size = query.shape[0]
+
+    mask_value = torch.finfo(torch.float16).min
+    mask_value = torch.full([], mask_value, dtype=value.dtype)
+
+    dropout_p = float(self.config.attention_dropout) if self.training else 0.0
+    if (batch_size == 1 or self.training) and self.attention_type == "global":
+        if query.shape[2] > 1:
+            # Prefill: more than one query token, let SDPA apply the causal mask itself.
+            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
+            )
+        else:
+            # Single-token decoding step: no causal masking is needed.
+            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False, scale=1.0
+            )
+    else:
+        query_length, key_length = query.size(-2), key.size(-2)
+
+        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+
+        causal_mask = torch.where(causal_mask, 0, mask_value)
+        if batch_size > 1:
+            # torch.Tensor.expand does no memory copy
+            causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
+
+        if attention_mask is not None:
+            attention_mask = causal_mask + attention_mask
+
+        sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False, scale=1.0
+        )
+
+    return sdpa_result, None
+
+
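For context on the fallback branch above: calling scaled_dot_product_attention with scale=1.0 and an additive float mask computes softmax(QK^T + mask)V, i.e. an unscaled masked softmax attention. A small self-contained check of that equivalence (the shapes, names q/k/v, and tolerance are illustrative assumptions, not part of the diff; the scale argument requires PyTorch >= 2.1, matching the version guard below):

import torch

# Illustrative shapes only: (batch, heads, seq_len, head_dim).
q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))

# Additive mask: 0 where attention is allowed, a large negative value elsewhere,
# mirroring torch.where(causal_mask, 0, mask_value) in the patch above.
allowed = torch.tril(torch.ones(4, 4, dtype=torch.bool))
add_mask = torch.zeros(4, 4).masked_fill(~allowed, torch.finfo(torch.float16).min)

sdpa = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=add_mask, dropout_p=0.0, is_causal=False, scale=1.0
)

# Reference: unscaled masked softmax attention computed by hand.
reference = torch.softmax(q @ k.transpose(-1, -2) + add_mask, dim=-1) @ v
assert torch.allclose(sdpa, reference, atol=1e-5)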
+class GptNeoModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        if is_transformers_version(">=", "4.45.0") and is_torch_version(">=", "2.1.0"):
+            self._model.config._orig_attn_implementation = self._model.config._attn_implementation
+            self._model.config._attn_implementation = "sdpa"
+            for layer in self._model.transformer.h:
+                self_attn = layer.attn.attention
+                self_attn._orig_attn = self_attn._attn
+                self_attn._attn = types.MethodType(_gpt_neo_attn_sdpa, self_attn)
+                # Keep the original bound forward so it can be restored on exit and
+                # called from the wrapper, then rebind forward to the SDPA-aware wrapper.
+                self_attn._orig_forward = self_attn.forward
+                self_attn.forward = types.MethodType(_gpt_neo_attn_forward, self_attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        if hasattr(self._model.config, "_orig_attn_implementation"):
+            self._model.config._attn_implementation = self._model.config._orig_attn_implementation
+            for layer in self._model.transformer.h:
+                layer.attn.attention.forward = layer.attn.attention._orig_forward
+                layer.attn.attention._attn = layer.attn.attention._orig_attn
+
+
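The __enter__/__exit__ pair above works by rebinding instance attributes with types.MethodType and putting the originals back on exit. A minimal, self-contained sketch of that patch/restore pattern, using a hypothetical Layer class rather than the real attention module:

import types


class Layer:
    # Hypothetical stand-in for the attention module being patched.
    def forward(self, x):
        return x + 1


def _patched_forward(self, x):
    # Replacement bound method; delegates to the saved original.
    return self._orig_forward(x) * 2


layer = Layer()

# __enter__-style patching: keep the original bound method, then rebind.
layer._orig_forward = layer.forward
layer.forward = types.MethodType(_patched_forward, layer)
assert layer.forward(1) == 4  # (1 + 1) * 2

# __exit__-style restore: put the original bound method back.
layer.forward = layer._orig_forward
assert layer.forward(1) == 2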
class Gemma2ModelPatcher(LlamaModelPatcher):
    def __init__(
        self,