@@ -722,7 +722,18 @@ def _mpt_attention_forward(
     else:
         past_key_value = (key_states, value_states)

-    attention_mask_sdpa = torch.ones(attention_mask.shape, dtype=query_states.dtype)
+    key_length = key_states.shape[-2]
+    query_length = seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2]
+    attention_mask_sdpa = torch.ones(
+        (query_states.shape[0], query_states.shape[1], query_states.shape[2], key_states.shape[2]),
+        dtype=query_states.dtype,
+    )
+    if position_bias is not None:
+        position_bias_query_index = max(0, position_bias.size(1) - query_length)
+        position_bias_key_index = max(0, position_bias.size(2) - key_length)
+
+        position_bias = position_bias[:, position_bias_query_index:, position_bias_key_index:]
+        attention_mask_sdpa += position_bias
     attention_mask_sdpa.masked_fill_(attention_mask, torch.finfo(query_states.dtype).min)
     context_states = torch.nn.functional.scaled_dot_product_attention(
         query_states,
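Note on the position-bias slicing added above: the max(0, ...) indexing keeps only the trailing rows and columns of the precomputed ALiBi bias so it lines up with the current query block and the full cached key length. A minimal standalone sketch of that slicing, assuming MPT's (num_heads, 1, max_positions) bias layout; the helper name and toy shapes are illustrative only:

import torch

def slice_position_bias(position_bias, query_length, key_length):
    # ALiBi-style additive bias, assumed shape (num_heads, 1, max_positions);
    # the size-1 query axis broadcasts over all query positions.
    q_start = max(0, position_bias.size(1) - query_length)
    k_start = max(0, position_bias.size(2) - key_length)
    return position_bias[:, q_start:, k_start:]

bias = torch.randn(4, 1, 16)  # 4 heads, bias precomputed for up to 16 positions
sliced = slice_position_bias(bias, query_length=10, key_length=10)
print(sliced.shape)  # torch.Size([4, 1, 10])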
@@ -732,6 +743,7 @@ def _mpt_attention_forward(
         dropout_p=self.attn_dropout_p,
         scale=self.softmax_scale,
     )
+
     context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1)
     attn_output = self.out_proj(context_states)

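Taken together, the two MPT hunks implement the usual additive-mask pattern for torch.nn.functional.scaled_dot_product_attention: build a float mask shaped (batch, heads, q_len, kv_len), add the position bias, and push disallowed positions to the dtype minimum. Starting from ones rather than zeros only shifts every unmasked score by a constant, which softmax cancels. A self-contained sketch of the pattern, assuming a boolean mask where True means "masked out"; the function name and toy shapes are illustrative, not part of the patch:

import torch
import torch.nn.functional as F

def sdpa_with_additive_bias(query, key, value, bool_mask, position_bias, dropout_p=0.0, scale=None):
    # query/key/value: (batch, heads, q_len, head_dim) / (batch, heads, kv_len, head_dim)
    # bool_mask: broadcastable to (batch, heads, q_len, kv_len), True where attention is NOT allowed
    # position_bias: additive bias broadcastable to the same shape, e.g. a sliced ALiBi tensor
    attn_mask = torch.ones(
        (query.shape[0], query.shape[1], query.shape[2], key.shape[2]), dtype=query.dtype
    )
    if position_bias is not None:
        attn_mask = attn_mask + position_bias
    # The constant 1.0 baseline is harmless: softmax is invariant to a per-row shift.
    attn_mask = attn_mask.masked_fill(bool_mask, torch.finfo(query.dtype).min)
    return F.scaled_dot_product_attention(query, key, value, attn_mask, dropout_p=dropout_p, scale=scale)

# Toy shapes: batch=2, heads=4, q_len=5, kv_len=5, head_dim=8.
q = torch.randn(2, 4, 5, 8)
k = torch.randn(2, 4, 5, 8)
v = torch.randn(2, 4, 5, 8)
causal = torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1).expand(2, 1, 5, 5)
alibi = torch.randn(4, 1, 5)  # broadcasts over batch and query positions
out = sdpa_with_additive_bias(q, k, v, causal, alibi)
print(out.shape)  # torch.Size([2, 4, 5, 8])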
@@ -764,17 +776,47 @@ def _internlm_attention_forward(
     use_cache: bool = False,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+    # from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+    from einops import rearrange
+
+    def rotate_half(x):
+        """Rotates half the hidden dims of the input."""
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+        """Applies Rotary Position Embedding to the query and key tensors."""
+        cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+        sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+        q_embed = (q * cos) + (rotate_half(q) * sin)
+        k_embed = (k * cos) + (rotate_half(k) * sin)
+        return q_embed, k_embed
+
+    def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+        """
+        This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+        num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+        """
+        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+        if n_rep == 1:
+            return hidden_states
+        hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

     bsz, q_len, _ = hidden_states.size()

     qkv_states = self.wqkv(hidden_states)

-    qkv_states = qkv_states.reshape(
-        qkv_states.shape[0], qkv_states.shape[1], -1, 2 + self.num_key_values_groups, self.head_dim
+    qkv_states = rearrange(
+        qkv_states,
+        "b q (h gs d) -> b q h gs d",
+        gs=2 + self.num_key_value_groups,
+        d=self.head_dim,
     )
+
     query_states = qkv_states[..., : self.num_key_value_groups, :]
-    query_states = query_states.reshape(query_states.shape[0], query_states.shape[1], -1, query_states.shape[-1])
+    query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d")
     key_states = qkv_states[..., -2, :]
     value_states = qkv_states[..., -1, :]

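For context on the InternLM2 hunk: the packed wqkv projection interleaves, per key/value head, num_key_value_groups query slices followed by one key and one value slice, which is what the "b q (h gs d) -> b q h gs d" rearrange unpacks; repeat_kv then expands the key/value heads to match the query-head count for grouped-query attention. A standalone sketch with made-up sizes (all shapes and names below are illustrative, not the real model config):

import torch
from einops import rearrange

batch, q_len = 2, 7
num_kv_heads, num_key_value_groups, head_dim = 4, 2, 16
hidden = num_kv_heads * (2 + num_key_value_groups) * head_dim  # packed q/k/v width

qkv_states = torch.randn(batch, q_len, hidden)

# Split the packed projection: per kv head there are `num_key_value_groups`
# query slices followed by one key slice and one value slice.
qkv_states = rearrange(
    qkv_states,
    "b q (h gs d) -> b q h gs d",
    gs=2 + num_key_value_groups,
    d=head_dim,
)
query_states = rearrange(qkv_states[..., :num_key_value_groups, :], "b q h gs d -> b q (h gs) d")
key_states = qkv_states[..., -2, :]
value_states = qkv_states[..., -1, :]
print(query_states.shape, key_states.shape)  # (2, 7, 8, 16) (2, 7, 4, 16)

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand (batch, num_kv_heads, seq, head_dim) -> (batch, num_kv_heads * n_rep, seq, head_dim),
    # equivalent to torch.repeat_interleave along the head dimension.
    b, kv_heads, slen, d = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    return hidden_states[:, :, None, :, :].expand(b, kv_heads, n_rep, slen, d).reshape(b, kv_heads * n_rep, slen, d)

# Bring keys to (batch, heads, seq, head_dim) layout, then repeat per query group.
key_states = repeat_kv(key_states.transpose(1, 2), num_key_value_groups)
print(key_states.shape)  # (2, 8, 7, 16)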