@@ -640,11 +640,25 @@ def _baichuan13b_atten_forward(
             attention_mask = attention_mask[:, :, -key_states.shape[-2] :, :]
         key_states = torch.cat([past_key_value[0], key_states], dim=2)
         value_states = torch.cat([past_key_value[1], value_states], dim=2)
+    if not output_attentions:
+        past_key_value = (key_states, value_states) if use_cache else None
+        attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask)
+        attn_weights = None
+    else:
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attention_mask is not None:
+            if q_len == 1:  # inference with cache
+                if len(attention_mask.size()) == 4:
+                    attention_mask = attention_mask[:, :, -1:, :]
+                else:
+                    attention_mask = attention_mask[:, -1:, :]
+            attn_weights = attn_weights + attention_mask
+            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+        attn_output = torch.matmul(attn_weights, value_states)
 
-    past_key_value = (key_states, value_states) if use_cache else None
-    attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask)
     attn_output = attn_output.transpose(1, 2)
-    attn_weights = None
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     attn_output = self.o_proj(attn_output)
 
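The hunk above dispatches to `F.scaled_dot_product_attention` when attention weights are not requested and falls back to an eager matmul/softmax path when `output_attentions=True`. A minimal standalone sketch of that equivalence (illustrative only, not part of the patch; shapes and names are made up):

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, n_heads, q_len, head_dim = 1, 4, 8, 16
q, k, v = (torch.randn(bsz, n_heads, q_len, head_dim) for _ in range(3))
# Additive float mask (0.0 = keep), broadcast over heads.
mask = torch.zeros(bsz, 1, q_len, q_len)

# SDPA path, as in the `not output_attentions` branch.
sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Eager path, as in the `output_attentions=True` branch.
weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
weights = torch.nn.functional.softmax(weights + mask, dim=-1)
eager_out = torch.matmul(weights, v)

assert torch.allclose(sdpa_out, eager_out, atol=1e-5)
```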
@@ -708,7 +722,7 @@ def __exit__(self, exc_type, exc_value, traceback):
                 layer.self_attn.forward = layer.self_attn._orig_forward
 
 
-def _mpt_attention_forward(
+def _mpt_sdpa_attention_forward(
     self,
     hidden_states: torch.Tensor,
     position_bias: torch.Tensor,
@@ -759,18 +773,73 @@ def _mpt_attention_forward(
     return attn_output, None, past_key_value
 
 
+def _mpt_block_forward(
+    self,
+    hidden_states: torch.Tensor,
+    position_bias: torch.Tensor,
+    attention_mask: torch.Tensor,
+    layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    use_cache: bool = False,
+    output_attentions: bool = False,
+):
+    # hidden_states: [batch_size, seq_length, hidden_size]
+    # Layer norm at the beginning of the transformer layer.
+    layernorm_output = self.norm_1(hidden_states)
+
+    residual = hidden_states
+
+    if not output_attentions:
+        # Self attention.
+        attn_outputs, attn_weights, past_key_value = self.attn(
+            layernorm_output,
+            position_bias=position_bias,
+            attention_mask=attention_mask,
+            past_key_value=layer_past,
+        )
+    else:
+        attn_outputs, attn_weights, past_key_value = self.attn._orig_forward(
+            layernorm_output,
+            position_bias=position_bias,
+            attention_mask=attention_mask,
+            past_key_value=layer_past,
+        )
+
+    hidden_states = self.resid_attn_dropout(attn_outputs) + residual
+
+    layernorm_output = self.norm_2(hidden_states)
+
+    # Get residual
+    residual = hidden_states
+
+    # MLP.
+    output = self.ffn(layernorm_output, residual)
+    outputs = (output,)
+
+    if use_cache:
+        outputs += (past_key_value,)
+
+    if output_attentions:
+        outputs += (attn_weights,)
+
+    return outputs
+
+
 class MPTModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
 
         if is_torch_version(">=", "2.1.0"):
             for block in self._model.transformer.blocks:
+                block._orig_forward = block.forward
+                block.forward = types.MethodType(_mpt_block_forward, block)
                 block.attn._orig_forward = block.attn.forward
-                block.attn.forward = types.MethodType(_mpt_attention_forward, block.attn)
+                block.attn.forward = types.MethodType(_mpt_sdpa_attention_forward, block.attn)
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         for block in self._model.transformer.blocks:
+            if hasattr(block, "_orig_forward"):
+                block.forward = block._orig_forward
             if hasattr(block.attn, "_orig_forward"):
                 block.attn.forward = block.attn._orig_forward
 
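The patcher above relies on the save/patch/restore idiom: stash the original bound `forward` on the module as `_orig_forward`, bind the replacement with `types.MethodType`, and put the original back in `__exit__`. A self-contained sketch of the same idiom on a toy module (the `ToyBlock`/`ToyPatcher` names are illustrative, not part of the patch):

```python
import types
import torch


class ToyBlock(torch.nn.Module):
    def forward(self, x):
        return x + 1


def _patched_forward(self, x):
    # Bound as a method, so `self` is the block instance; reuse the saved original.
    return self._orig_forward(x) * 2


class ToyPatcher:
    def __init__(self, block):
        self.block = block

    def __enter__(self):
        # Save the original bound method, then install the replacement on the instance.
        self.block._orig_forward = self.block.forward
        self.block.forward = types.MethodType(_patched_forward, self.block)
        return self.block

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the original forward, as MPTModelPatcher.__exit__ does.
        if hasattr(self.block, "_orig_forward"):
            self.block.forward = self.block._orig_forward


block = ToyBlock()
x = torch.zeros(1)
with ToyPatcher(block):
    assert block(x).item() == 2.0  # patched path
assert block(x).item() == 1.0  # original restored
```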
@@ -848,17 +917,21 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
+    if not output_attentions:
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim))
+        )
+        attn_weights = None
+    else:
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
 
-    attn_output = torch.nn.functional.scaled_dot_product_attention(
-        query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim))
-    )
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
     attn_output = self.wo(attn_output)
 
-    attn_weights = None
-
     return attn_output, attn_weights, past_key_value
 
 
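This hunk passes an explicit `scale=(1 / math.sqrt(self.head_dim))` to SDPA (the `scale` kwarg requires PyTorch >= 2.1) and, on the eager branch, runs the softmax in float32 before casting back. A small sketch, under those assumptions, checking that the explicit scale matches SDPA's default scaling and that the eager path agrees (no mask, toy shapes):

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, n_heads, q_len, head_dim = 1, 2, 4, 8
q, k, v = (torch.randn(bsz, n_heads, q_len, head_dim) for _ in range(3))

# Explicit scale matches SDPA's default of 1/sqrt(head_dim).
out_default = F.scaled_dot_product_attention(q, k, v)
out_explicit = F.scaled_dot_product_attention(q, k, v, scale=1 / math.sqrt(head_dim))
assert torch.allclose(out_default, out_explicit)

# Eager branch as in the hunk: softmax upcast to float32, then cast back.
weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
weights = torch.nn.functional.softmax(weights, dim=-1, dtype=torch.float32).to(q.dtype)
eager_out = torch.matmul(weights, v)
assert torch.allclose(out_explicit, eager_out, atol=1e-5)
```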