@@ -2681,6 +2681,74 @@ def __exit__(self, exc_type, exc_value, traceback):
        unpatch_update_causal_mask(self._model, "gpt_neox_japanese")


+def _gpt_neo_sdpa_attn(self, query, key, value, attention_mask=None, head_mask=None):
+    # Keep the attention weights computation in fp32 to avoid overflow issues
+    query = query.to(torch.float32)
+    key = key.to(torch.float32)
+
+    # Apply sliding window masking for local attention layers
+    query_length, key_length = query.size(-2), key.size(-2)
+    causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+    # Different from the original: to prevent overflow, apply the minimum value to the mask instead of directly to the weights
+    mask_value = torch.finfo(torch.float16).min
+    # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+    # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+    mask_value = torch.tensor(mask_value, dtype=query.dtype).to(query.device)
+    if attention_mask is None:
+        attention_mask = torch.ones_like(causal_mask)
+    attention_mask = torch.where(causal_mask, attention_mask[:, :, :, : key.shape[-2]], mask_value)
+
+    # Mask heads if we want to
+    if head_mask is not None:
+        attention_mask = attention_mask * head_mask
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
+
+    return attn_output, None
+
+
+def _gpt_neo_attn_forward(
+    self,
+    hidden_states,
+    attention_mask=None,
+    layer_past=None,
+    head_mask=None,
+    use_cache=False,
+    output_attentions=False,
+    cache_position=None,
+):
+    if output_attentions:
+        self._attn = self._orig_attn
+
+    return self._orig_forward(
+        hidden_states,
+        attention_mask=attention_mask,
+        layer_past=layer_past,
+        head_mask=head_mask,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        cache_position=cache_position,
+    )
+
+
+class GptNeoModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        if is_transformers_version(">=", "4.45.0") and is_torch_version(">=", "2.1.0"):
+            for layer in self._model.transformer.h:
+                self_attn = layer.attn.attention
+                self_attn._orig_attn = self_attn._attn
+                self_attn._attn = types.MethodType(_gpt_neo_sdpa_attn, self_attn)
+                self_attn._orig_forward = types.MethodType(_gpt_neo_attn_forward, self_attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for layer in self._model.transformer.h:
+            if hasattr(layer.attn.attention, "_orig_forward"):
+                layer.attn.attention.forward = layer.attn.attention._orig_forward
+                layer.attn.attention._attn = layer.attn.attention._orig_attn
+
+
 class Gemma2ModelPatcher(LlamaModelPatcher):
     def __init__(
         self,
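For context, the patcher relies on binding module-level functions as instance methods with `types.MethodType` and restoring the originals on exit. A minimal, self-contained sketch of that pattern, using a hypothetical `ToyAttention` module rather than the real `GPTNeoSelfAttention`:

```python
import types

import torch


class ToyAttention(torch.nn.Module):
    """Hypothetical stand-in for GPTNeoSelfAttention, used only to illustrate the pattern."""

    def _attn(self, query, key, value):
        # Eager attention: explicit softmax over the scaled score matrix.
        weights = torch.softmax(query @ key.transpose(-1, -2) / query.size(-1) ** 0.5, dim=-1)
        return weights @ value, weights


def _sdpa_attn(self, query, key, value):
    # Replacement bound as an instance method at patch time, mirroring how
    # GptNeoModelPatcher swaps `_attn` for an SDPA-based implementation.
    return torch.nn.functional.scaled_dot_product_attention(query, key, value), None


attn = ToyAttention()

# Patch: remember the original bound method, then bind the replacement.
attn._orig_attn = attn._attn
attn._attn = types.MethodType(_sdpa_attn, attn)

q = k = v = torch.randn(1, 2, 4, 8)  # (batch, heads, seq_len, head_dim)
out, _ = attn._attn(q, k, v)

# Unpatch: restore the original method, as done in __exit__.
attn._attn = attn._orig_attn
```

Binding with `types.MethodType` (rather than assigning the bare function) makes `self` inside the replacement refer to the patched attention module, so attributes such as `self.bias` and `self._orig_attn` resolve against that layer.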