@@ -3650,3 +3650,93 @@ def __exit__(self, exc_type, exc_value, traceback):
            block_sparse_moe.router.forward = block_sparse_moe.router._orig_forward
            block_sparse_moe.input_linear.forward = block_sparse_moe.input_linear._orig_forward
            block_sparse_moe.output_linear.forward = block_sparse_moe.output_linear._orig_forward
+
+
+# copied from https://github.com/huggingface/transformers/blob/v4.46.3/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L401
+def gpt_bigcode_attn(self, query, key, value, attention_mask=None, head_mask=None):
+    if head_mask is not None:
+        # The super dispatch is done in the forward.
+        raise ValueError("PyTorch SDPA does not support head_mask. Please open an issue in Transformers repository.")
+
+    scale = None
+    if not self.scale_attn_weights:
+        scale = 1
+
+    # MQA models: (batch_size, query_length, num_heads * head_dim)
+    # MHA models: (batch_size, num_heads, query_length, head_dim)
+    query_shape = query.shape
+    batch_size = query_shape[0]
+    key.shape[-2]
+
+    if self.multi_query:
+        query_length = query_shape[1]
+
+        # SDPA requires the dimension [..., sequence_length, head_dim].
+        query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
+
+        # Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions.
+        key = key.unsqueeze(1)
+        value = value.unsqueeze(1)
+
+        # Although these expand are not numerically useful, PyTorch can not dispatch to memory-efficient backend
+        # and flash attention backend (No available kernel. Aborting execution.) from the shapes
+        # query = [batch_size, num_heads, query_length, head_dim]
+        # key = [batch_size, 1, past_length, head_dim]
+        # value = [batch_size, 1, past_length, head_dim]
+        #
+        # torch==2.1.2 is bugged with non-contiguous inputs with custom attn_mask (https://github.com/pytorch/pytorch/issues/112577), hence the check.
+        if is_torch_version(">=", "2.2.0"):
+            key = key.expand(-1, self.num_heads, -1, -1)
+            value = value.expand(-1, self.num_heads, -1, -1)
+    else:
+        query_length = query_shape[-1]
+
+        # See the comment above.
+        if query.device.type == "cuda" and attention_mask is not None:
+            query = query.contiguous()
+            key = key.contiguous()
+            value = value.contiguous()
+
+    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+    # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not
+    # create a causal mask in case query_length == 1.
+    is_causal = True if self.is_causal and attention_mask is None and query_length > 1 else False
+    # different from the original: since the model weights may be loaded in their original format, transformer.wte dtype may differ from the query dtype
+    if attention_mask is not None:
+        attention_mask = attention_mask.to(query.dtype)
+    sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+        query,
+        key,
+        value,
+        attn_mask=attention_mask,
+        dropout_p=self.attn_pdrop if self.training else 0.0,
+        is_causal=is_causal,
+        scale=scale,
+    )
+
+    if self.multi_query:
+        # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)
+        sdpa_result = sdpa_result.transpose(1, 2)
+
+        # Reshape is kind of expensive here, as it does a memory copy,
+        # but I did not manage to make away without it (logits do not match when using view)
+        # (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim)
+        sdpa_result = sdpa_result.reshape(query_shape)
+
+    return sdpa_result, None
+
+
+class GptBigCodeModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        if getattr(self._model.config, "_attn_implementation", "eager") == "sdpa":
+            for layer in self._model.transformer.h:
+                layer.attn._orig_attn = layer.attn._attn
+                layer.attn._attn = types.MethodType(gpt_bigcode_attn, layer.attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        if getattr(self._model.config, "_attn_implementation", "eager") == "sdpa":
+            for layer in self._model.transformer.h:
+                layer.attn._attn = layer.attn._orig_attn
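For context, the patcher above follows a common enter/exit method-swapping pattern: __enter__ rebinds each layer's _attn to the patched function via types.MethodType and stashes the original, and __exit__ restores it. Below is a minimal, self-contained sketch of that pattern; all names in the sketch are illustrative and not part of this change.

# Illustrative sketch only: temporarily replace a bound method inside a
# context manager and restore the original on exit.
import types


class _Layer:
    def _attn(self, x):
        return x + 1


def _patched_attn(self, x):
    # stand-in for a patched attention implementation
    return x * 2


class _SketchPatcher:
    def __init__(self, layers):
        self._layers = layers

    def __enter__(self):
        for layer in self._layers:
            layer._orig_attn = layer._attn
            layer._attn = types.MethodType(_patched_attn, layer)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        for layer in self._layers:
            layer._attn = layer._orig_attn


layers = [_Layer()]
with _SketchPatcher(layers):
    assert layers[0]._attn(3) == 6  # patched behaviour inside the context
assert layers[0]._attn(3) == 4  # original behaviour restored on exit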