@@ -673,6 +673,71 @@ def _baichuan13b_atten_forward(
    return attn_output, attn_weights, past_key_value


+def _baichuan7b_attn_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    def rotate_half(x):
+        """Rotates half the hidden dims of the input."""
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+        # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+        cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
+        sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
+        cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+        sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+        q_embed = (q * cos) + (rotate_half(q) * sin)
+        k_embed = (k * cos) + (rotate_half(k) * sin)
+        return q_embed, k_embed
+
+    bsz, q_len, _ = hidden_states.size()
+
+    proj = self.W_pack(hidden_states)
+    proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
+    query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+    # [bsz, nh, t, hd]
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+    if not output_attentions:
+        attn_weights = None
+        attn_output = F.scaled_dot_product_attention(
+            query_states, key_states, value_states, attn_mask=attention_mask, scale=1 / math.sqrt(self.head_dim)
+        )
+    else:
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+    attn_output = attn_output.transpose(1, 2)
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+    attn_output = self.o_proj(attn_output)
+
+    return attn_output, attn_weights, past_key_value
+
+
class BaichuanModelPatcher(DecoderModelPatcher):
    def __init__(
        self,
@@ -720,13 +785,18 @@ def forward(
            for layer in self._model.model.layers:
                layer.self_attn._orig_forward = layer.self_attn.forward
                layer.self_attn.forward = types.MethodType(_baichuan13b_atten_forward, layer.self_attn)
+        else:
+            for layer in self._model.model.layers:
+                layer.self_attn._orig_forward = layer.self_attn.forward
+                layer.self_attn.forward = types.MethodType(_baichuan7b_attn_forward, layer.self_attn)

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        if hasattr(self._model, "_orig_forward"):
            self._model.forward = self._model._orig_forward

-        for layer in self._model.model.layers:
+        for layer in self._model.model.layers:
+            if hasattr(layer.self_attn, "_orig_forward"):
                layer.self_attn.forward = layer.self_attn._orig_forward
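Note: both branches of __enter__ patch the attention layers by rebinding each module's forward through types.MethodType, and __exit__ now restores the original method only when _orig_forward was actually saved. A minimal, self-contained sketch of that rebinding pattern, using a hypothetical ToyAttention module (not part of the Baichuan code), looks like this:

import types

import torch


class ToyAttention(torch.nn.Module):
    # Hypothetical stand-in for a decoder layer's self-attention module.
    def forward(self, hidden_states):
        return hidden_states  # original behaviour: identity


def _patched_attn_forward(self, hidden_states):
    # Replacement forward; the real patch above computes attention with
    # F.scaled_dot_product_attention instead of this toy doubling.
    return hidden_states * 2


attn = ToyAttention()

# Save the original bound method so it can be restored on exit.
attn._orig_forward = attn.forward
# MethodType binds the plain function to this instance, shadowing the class forward.
attn.forward = types.MethodType(_patched_attn_forward, attn)
assert torch.equal(attn(torch.ones(2)), torch.tensor([2.0, 2.0]))

# Restore, guarded the same way as the patched __exit__.
if hasattr(attn, "_orig_forward"):
    attn.forward = attn._orig_forward
assert torch.equal(attn(torch.ones(2)), torch.ones(2))

Binding via types.MethodType (rather than assigning the bare function) keeps self pointing at the patched module instance, which is why the patched forward in the diff can reach attributes such as W_pack, rotary_emb, and o_proj.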