import math
import types
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+ import inspect

import torch
import torch.nn.functional as F
@@ -601,6 +602,46 @@ def __exit__(self, exc_type, exc_value, traceback):
        self._model.config.fp16 = self.original_fp16


+ def _baichuan13b_atten_forward(
+     self,
+     hidden_states: torch.Tensor,
+     attention_mask: Optional[torch.Tensor] = None,
+     past_key_value: Optional[Tuple[torch.Tensor]] = None,
+     output_attentions: bool = False,
+     use_cache: bool = True,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+     bsz, q_len, _ = hidden_states.size()
+
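+     # W_pack is the fused QKV projection; split it into per-head query, key and value states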
+     proj = self.W_pack(hidden_states)
+     proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
+     query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+     key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+     value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+     kv_seq_len = key_states.shape[-2]
+     if past_key_value is not None:
+         kv_seq_len += past_key_value[0].shape[-2]
+
+     if past_key_value is not None:
+         # reuse k, v, self_attention
+         if attention_mask is not None:
+             attention_mask = attention_mask[:, :, -key_states.shape[-2] :, :]
+         key_states = torch.cat([past_key_value[0], key_states], dim=2)
+         value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+     past_key_value = (key_states, value_states) if use_cache else None
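+     # compute attention with PyTorch's scaled_dot_product_attention rather than the model's manual attention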
+     attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask)
+     attn_output = attn_output.transpose(1, 2)
+     attn_weights = None
+     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+     attn_output = self.o_proj(attn_output)
+
+     if not output_attentions:
+         attn_weights = None
+
+     return attn_output, attn_weights, past_key_value
+
+
class BaichuanModelPatcher(DecoderModelPatcher):
    def __init__(
        self,
@@ -613,6 +654,50 @@ def __init__(
        if hasattr(self._model.lm_head, "first_flag"):
            self._model(torch.ones((1, 10), dtype=torch.int64), torch.ones((1, 10), dtype=torch.int64))

+     def __enter__(self):
+         super().__enter__()
+         # override signature to have position_ids
+         if "position_ids" not in inspect.signature(self._model.forward).parameters:
+             self._model._orig_forward = self._model.forward
+
+             def forward(
+                 self,
+                 input_ids: torch.LongTensor = None,
+                 attention_mask: Optional[torch.Tensor] = None,
+                 past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
+                 inputs_embeds: Optional[torch.FloatTensor] = None,
+                 labels: Optional[torch.LongTensor] = None,
+                 use_cache: Optional[bool] = None,
+                 output_attentions: Optional[bool] = False,
+                 output_hidden_states: Optional[bool] = False,
+                 return_dict: Optional[bool] = True,
+                 position_ids: Optional[torch.LongTensor] = None,
+             ):
+                 return self._orig_forward(
+                     input_ids=input_ids,
+                     attention_mask=attention_mask,
+                     past_key_values=past_key_values,
+                     inputs_embeds=inputs_embeds,
+                     labels=labels,
+                     use_cache=past_key_values is not None,
+                     output_attentions=output_attentions,
+                     output_hidden_states=output_hidden_states,
+                     return_dict=self.config.return_dict,
+                 )
+
+             self._model.forward = types.MethodType(forward, self._model)
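+             # route each decoder layer's attention through the SDPA-based _baichuan13b_atten_forward defined above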
+             for layer in self._model.model.layers:
+                 layer.self_attn._orig_forward = layer.self_attn.forward
+                 layer.self_attn.forward = types.MethodType(_baichuan13b_atten_forward, layer.self_attn)
+
+     def __exit__(self, exc_type, exc_value, traceback):
+         super().__exit__(exc_type, exc_value, traceback)
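+         # undo the patching from __enter__ and restore the original forward methods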
+         if hasattr(self._model, "_orig_forward"):
+             self._model.forward = self._model._orig_forward
+
+             for layer in self._model.model.layers:
+                 layer.self_attn.forward = layer.self_attn._orig_forward
+

def _mpt_attention_forward(
    self,