@@ -951,15 +951,107 @@ def __exit__(self, exc_type, exc_value, traceback):
            block.attention.forward = block.attention._orig_forward


+# Adapted from https://github.com/huggingface/transformers/blob/ccdabc5642bf84849af93f591e207dc625c8e1e1/src/transformers/models/phi3/modeling_phi3.py#L426
+def _phi3_self_attn_sdpa_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    if output_attentions:
+        return self._orig_forward(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+
+    # TODO: remove llama imports once transformers with phi3 support is released
+    try:
+        from transformers.models.phi3.modeling_phi3 import apply_rotary_pos_emb, repeat_kv
+    except ImportError:
+        from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+
+    bsz, q_len, _ = hidden_states.size()
+
+    qkv = self.qkv_proj(hidden_states)
+    query_pos = self.num_heads * self.head_dim
+    query_states = qkv[..., :query_pos]
+    key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
+    value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
+
+    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+    cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
+
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+    if past_key_value is not None:
+        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    if attention_mask is not None:
+        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+            )
+
+    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+    # Reference: https://github.com/pytorch/pytorch/issues/112577.
+    if query_states.device.type == "cuda" and attention_mask is not None:
+        query_states = query_states.contiguous()
+        key_states = key_states.contiguous()
+        value_states = value_states.contiguous()
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query_states,
+        key_states,
+        value_states,
+        attn_mask=attention_mask,
+        dropout_p=self.attention_dropout if self.training else 0.0,
+        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+        is_causal=self.is_causal and attention_mask is None and q_len > 1,
+    )
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+
+    return attn_output, None, past_key_value
+
+
class Phi3ModelPatcher(DecoderModelPatcher):
    def __enter__(self):
        super().__enter__()
-
        # https://github.com/huggingface/transformers/blob/30ee508c6c92a1c0aa0281d193c7c0fb815b8d2f/src/transformers/models/phi3/modeling_phi3.py#L113
        # init inv_freq for torchscript tracing
        for layer in self._model.model.layers:
+            if is_torch_version(">=", "2.1.0"):
+                orig_self_attn_fwd = layer.self_attn.forward
+                layer.self_attn.forward = types.MethodType(_phi3_self_attn_sdpa_forward, layer.self_attn)
+                layer.self_attn._orig_forward = orig_self_attn_fwd
+
            if layer.self_attn.rotary_emb.inv_freq is None:
                rotary_emb = layer.self_attn.rotary_emb
                layer.self_attn.rotary_emb.inv_freq = 1.0 / (
                    rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
                )
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for layer in self._model.model.layers:
+            if hasattr(layer.self_attn, "_orig_forward"):
+                layer.self_attn.forward = layer.self_attn._orig_forward
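
Note (not part of the diff): `Phi3ModelPatcher.__enter__` rebinds each layer's `self_attn.forward` to the SDPA variant with `types.MethodType`, stashing the original on `_orig_forward`, and `__exit__` restores it. Below is a minimal, self-contained sketch of that patch/restore pattern; `ToyAttention`, `_patched_forward`, and the toy tensors are illustrative stand-ins, not optimum-intel or transformers APIs.

```python
# Minimal sketch of the patch/restore pattern used by Phi3ModelPatcher above.
# ToyAttention and _patched_forward are illustrative, not real library APIs.
import types

import torch


class ToyAttention(torch.nn.Module):
    def forward(self, hidden_states):
        # Original behaviour: identity.
        return hidden_states


def _patched_forward(self, hidden_states):
    # Replacement forward; once bound with types.MethodType, `self` is the module instance.
    return hidden_states * 2


attn = ToyAttention()

# __enter__-style patching: remember the original forward, then bind the replacement.
attn._orig_forward = attn.forward
attn.forward = types.MethodType(_patched_forward, attn)

x = torch.ones(1, 4)
assert torch.equal(attn.forward(x), x * 2)  # patched path is active

# __exit__-style restore: put the original bound method back so the model is left untouched.
if hasattr(attn, "_orig_forward"):
    attn.forward = attn._orig_forward
assert torch.equal(attn.forward(x), x)  # original path again
```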