@@ -2654,6 +2654,76 @@ def __exit__(self, exc_type, exc_value, traceback):
         unpatch_update_causal_mask(self._model, "gpt_neox_japanese")


+def _gpt_neo_sdpa_attn(self, query, key, value, attention_mask=None, head_mask=None):
+    # Keep the attention weights computation in fp32 to avoid overflow issues
+    query = query.to(torch.float32)
+    key = key.to(torch.float32)
+
+    # Apply sliding window masking for local attention layers
+    query_length, key_length = query.size(-2), key.size(-2)
+    causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+    # Different from the original: to prevent overflow, apply the mask value to the mask instead of directly to the weights
+    mask_value = torch.finfo(torch.float16).min
+    # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+    # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+    mask_value = torch.tensor(mask_value, dtype=query.dtype).to(query.device)
+    if attention_mask is None:
+        attention_mask = torch.ones_like(causal_mask)
+    attention_mask = torch.where(causal_mask, attention_mask[:, :, :, : key.shape[-2]], mask_value)
+
+    # Mask heads if we want to
+    if head_mask is not None:
+        attention_mask = attention_mask * head_mask
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
+
+    return attn_output, None
+
+
+def _gpt_neo_attn_forward(
+    self,
+    hidden_states,
+    attention_mask=None,
+    layer_past=None,
+    head_mask=None,
+    use_cache=False,
+    output_attentions=False,
+    cache_position=None,
+):
+    # The SDPA-based _attn cannot return attention weights, so fall back to the original eager attention
+    if output_attentions:
+        self._attn = self._orig_attn
+
+    return self._orig_forward(
+        hidden_states,
+        attention_mask=attention_mask,
+        layer_past=layer_past,
+        head_mask=head_mask,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        cache_position=cache_position,
+    )
+
+
+class GptNeoModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        if is_transformers_version(">=", "4.45.0") and is_torch_version(">=", "2.1.0"):
+            for layer in self._model.transformer.h:
+                self_attn = layer.attn.attention
+                self_attn._orig_attn = self_attn._attn
+                self_attn._attn = types.MethodType(_gpt_neo_sdpa_attn, self_attn)
+                self_attn._orig_forward = self_attn.forward
+                self_attn.forward = types.MethodType(_gpt_neo_attn_forward, self_attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for layer in self._model.transformer.h:
+            if hasattr(layer.attn.attention, "_orig_forward"):
+                layer.attn.attention.forward = layer.attn.attention._orig_forward
+                layer.attn.attention._attn = layer.attn.attention._orig_attn
+
+
 
 class Gemma2ModelPatcher(LlamaModelPatcher):
     def __init__(
         self,
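For reference, a minimal standalone sketch of the masking trick that _gpt_neo_sdpa_attn relies on: rather than writing a large negative value into the attention weights, the boolean causal bias is folded into a float attn_mask that torch.nn.functional.scaled_dot_product_attention adds to the scores. Tensor shapes and the use of standard 1/sqrt(d) scaling on both paths are illustrative assumptions; this is not a drop-in reproduction of GPT-Neo's attention (fp32 casting and head_mask handling are omitted).

# Minimal sketch, assuming illustrative shapes and standard scaling on both paths;
# GPT-Neo's own _attn additionally casts query/key to fp32 and applies head_mask.
import torch

torch.manual_seed(0)
batch, heads, seq, head_dim = 1, 2, 8, 16
query = torch.randn(batch, heads, seq, head_dim)
key = torch.randn(batch, heads, seq, head_dim)
value = torch.randn(batch, heads, seq, head_dim)

# Boolean lower-triangular buffer, analogous to the model's causal `bias`
causal_bias = torch.tril(torch.ones(seq, seq)).bool().view(1, 1, seq, seq)
mask_value = torch.tensor(torch.finfo(torch.float16).min)

# Eager reference: write the mask value into the raw attention weights, then softmax
attn_weights = torch.matmul(query, key.transpose(-1, -2)) / head_dim**0.5
attn_weights = torch.where(causal_bias, attn_weights, mask_value)
eager_out = torch.matmul(torch.softmax(attn_weights, dim=-1), value)

# SDPA path, as in the patch: move the mask value into a float attn_mask instead
attn_mask = torch.where(causal_bias, torch.tensor(0.0), mask_value)
sdpa_out = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)

print(torch.allclose(eager_out, sdpa_out, atol=1e-5))  # expected: True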