@@ -3601,3 +3601,93 @@ def __exit__(self, exc_type, exc_value, traceback):
         for block in self._model.blocks:
             block.forward = block._orig_forward
             block.attn.forward = block.attn._orig_forward
+
+
+# copied from https://github.com/huggingface/transformers/blob/v4.46.3/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L401
+def gpt_bigcode_attn(self, query, key, value, attention_mask=None, head_mask=None):
+    if head_mask is not None:
+        # The super dispatch is done in the forward.
+        raise ValueError("PyTorch SDPA does not support head_mask. Please open an issue in Transformers repository.")
+
+    scale = None
+    if not self.scale_attn_weights:
+        scale = 1
+
+    # MQA models: (batch_size, query_length, num_heads * head_dim)
+    # MHA models: (batch_size, num_heads, query_length, head_dim)
+    query_shape = query.shape
+    batch_size = query_shape[0]
+    key.shape[-2]
+
+    if self.multi_query:
+        query_length = query_shape[1]
+
+        # SDPA requires the dimension [..., sequence_length, head_dim].
+        query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
+
+        # Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions.
+        key = key.unsqueeze(1)
+        value = value.unsqueeze(1)
+
+        # Although these expand are not numerically useful, PyTorch can not dispatch to memory-efficient backend
+        # and flash attention backend (No available kernel. Aborting execution.) from the shapes
+        # query = [batch_size, num_heads, query_length, head_dim]
+        # key = [batch_size, 1, past_length, head_dim]
+        # value = [batch_size, 1, past_length, head_dim]
+        #
+        # torch==2.1.2 is bugged with non-contiguous inputs with custom attn_mask (https://github.com/pytorch/pytorch/issues/112577), hence the check.
+        if is_torch_version(">=", "2.2.0"):
+            key = key.expand(-1, self.num_heads, -1, -1)
+            value = value.expand(-1, self.num_heads, -1, -1)
+    else:
+        query_length = query_shape[-1]
+
+        # See the comment above.
+        if query.device.type == "cuda" and attention_mask is not None:
+            query = query.contiguous()
+            key = key.contiguous()
+            value = value.contiguous()
+
+    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+    # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not
+    # create a causal mask in case query_length == 1.
+    is_causal = True if self.is_causal and attention_mask is None and query_length > 1 else False
+    # different from original, due to loading model weights in original format transformer.wte dtype may be different from query dtype
+    if attention_mask is not None:
+        attention_mask = attention_mask.to(query.dtype)
+    sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+        query,
+        key,
+        value,
+        attn_mask=attention_mask,
+        dropout_p=self.attn_pdrop if self.training else 0.0,
+        is_causal=is_causal,
+        scale=scale,
+    )
+
+    if self.multi_query:
+        # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)
+        sdpa_result = sdpa_result.transpose(1, 2)
+
+        # Reshape is kind of expensive here, as it does a memory copy,
+        # but I did not manage to make away without it (logits do not match when using view)
+        # (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim)
+        sdpa_result = sdpa_result.reshape(query_shape)
+
+    return sdpa_result, None
+
+
+class GptBigCodeModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        if getattr(self._model.config, "_attn_implementation", "eager") == "sdpa":
+            for layer in self._model.transformer.h:
+                layer.attn._orig_attn = layer.attn._attn
+                layer.attn._attn = types.MethodType(gpt_bigcode_attn, layer.attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        if getattr(self._model.config, "_attn_implementation", "eager") == "sdpa":
+            for layer in self._model.transformer.h:
+                layer.attn._attn = layer.attn._orig_attn
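For reference, the mechanism this diff relies on is plain method monkey-patching with `types.MethodType`: `__enter__` stores the original bound `_attn` and binds the replacement, and `__exit__` restores it. Below is a minimal, self-contained sketch of that pattern; `ToyAttention`, `ToyPatcher`, and `patched_attn` are made-up names for illustration and are not part of the patch above or of any library API.

import types


class ToyAttention:
    def _attn(self, query, key, value):
        return "original attention"


def patched_attn(self, query, key, value):
    # Replacement that will be bound to a ToyAttention instance.
    return "patched attention"


class ToyPatcher:
    def __init__(self, module):
        self._module = module

    def __enter__(self):
        # Keep a reference to the original bound method so it can be restored.
        self._module._orig_attn = self._module._attn
        self._module._attn = types.MethodType(patched_attn, self._module)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Put the original method back, regardless of what happened inside the block.
        self._module._attn = self._module._orig_attn


attn = ToyAttention()
with ToyPatcher(attn):
    assert attn._attn(None, None, None) == "patched attention"
assert attn._attn(None, None, None) == "original attention"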