@@ -951,15 +951,103 @@ def __exit__(self, exc_type, exc_value, traceback):
            block.attention.forward = block.attention._orig_forward
+# Adapted from Phi3Attention.forward
+def _phi3_self_attn_sdpa_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    if output_attentions:
+        return self._orig_forward(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+
+    from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+
+    bsz, q_len, _ = hidden_states.size()
+
+    qkv = self.qkv_proj(hidden_states)
+    query_pos = self.num_heads * self.head_dim
+    query_states = qkv[..., :query_pos]
+    key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
+    value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
+
+    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+    cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
+
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+    if past_key_value is not None:
+        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    if attention_mask is not None:
+        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+            )
+
+    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+    # Reference: https://github.com/pytorch/pytorch/issues/112577.
+    if query_states.device.type == "cuda" and attention_mask is not None:
+        query_states = query_states.contiguous()
+        key_states = key_states.contiguous()
+        value_states = value_states.contiguous()
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query_states,
+        key_states,
+        value_states,
+        attn_mask=attention_mask,
+        dropout_p=self.attention_dropout if self.training else 0.0,
+        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+        is_causal=self.is_causal and attention_mask is None and q_len > 1,
+    )
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+
+    return attn_output, None, past_key_value
+
+
class Phi3ModelPatcher(DecoderModelPatcher):
    def __enter__(self):
        super().__enter__()
-
        # https://github.com/huggingface/transformers/blob/30ee508c6c92a1c0aa0281d193c7c0fb815b8d2f/src/transformers/models/phi3/modeling_phi3.py#L113
        # init inv_freq for torchscript tracing
        for layer in self._model.model.layers:
+            if is_torch_version(">=", "2.1.0"):
+                orig_self_attn_fwd = layer.self_attn.forward
+                layer.self_attn.forward = types.MethodType(_phi3_self_attn_sdpa_forward, layer.self_attn)
+                layer.self_attn._orig_forward = orig_self_attn_fwd
+
            if layer.self_attn.rotary_emb.inv_freq is None:
                rotary_emb = layer.self_attn.rotary_emb
                layer.self_attn.rotary_emb.inv_freq = 1.0 / (
                    rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
                )
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for layer in self._model.model.layers:
+            if hasattr(layer.self_attn, "_orig_forward"):
+                layer.self_attn.forward = layer.self_attn._orig_forward
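
The `__enter__`/`__exit__` pair above swaps each layer's `self_attn.forward` for the SDPA variant and restores the original afterwards. Below is a minimal, standalone sketch of that same patch/restore pattern on a toy module; `ToyAttention`, `ToyPatcher`, and `_patched_forward` are hypothetical names used only for illustration and are not part of this PR or of optimum-intel.

import types

import torch


class ToyAttention(torch.nn.Module):
    def forward(self, hidden_states):
        return hidden_states


def _patched_forward(self, hidden_states):
    # Replacement forward; bound as a method, so `self` is the patched module instance.
    return hidden_states * 2


class ToyPatcher:
    def __init__(self, module):
        self._module = module

    def __enter__(self):
        # Keep the original forward and bind the replacement to the instance.
        self._module._orig_forward = self._module.forward
        self._module.forward = types.MethodType(_patched_forward, self._module)
        return self._module

    def __exit__(self, exc_type, exc_value, traceback):
        # Put the original forward back, mirroring the patcher above.
        if hasattr(self._module, "_orig_forward"):
            self._module.forward = self._module._orig_forward


attn = ToyAttention()
x = torch.ones(1, 3)
with ToyPatcher(attn):
    assert torch.equal(attn(x), x * 2)  # patched forward active inside the context
assert torch.equal(attn(x), x)  # original forward restored afterwards

Binding the replacement with `types.MethodType` keeps `self` pointing at the patched module, so the replacement can reuse its attributes and weights, which is exactly what `_phi3_self_attn_sdpa_forward` does with `qkv_proj`, `o_proj`, and `rotary_emb`.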
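
For reference, here is a self-contained check (not part of the patch) that `torch.nn.functional.scaled_dot_product_attention` with a 4D additive mask matches an explicit softmax(QK^T / sqrt(d) + mask) V computation, which is the contract the patched forward relies on; the shapes and tolerance are arbitrary illustration values.

import math

import torch
import torch.nn.functional as F

bsz, n_heads, q_len, head_dim = 2, 4, 8, 16
q = torch.randn(bsz, n_heads, q_len, head_dim)
k = torch.randn(bsz, n_heads, q_len, head_dim)
v = torch.randn(bsz, n_heads, q_len, head_dim)

# Additive causal mask shaped (bsz, 1, q_len, kv_seq_len), as expected by the patched forward.
mask = torch.full((q_len, q_len), float("-inf")).triu(1).expand(bsz, 1, q_len, q_len)

sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

# Manual reference: scaled dot-product scores, additive mask, softmax, weighted sum of values.
scores = q @ k.transpose(-1, -2) / math.sqrt(head_dim) + mask
manual_out = torch.softmax(scores, dim=-1) @ v

assert torch.allclose(sdpa_out, manual_out, atol=1e-5)

When no explicit mask is supplied, passing `is_causal=True` produces the same causal behaviour, which is why the patched forward only sets `is_causal` when `attention_mask is None` and `q_len > 1`.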