@@ -603,6 +603,125 @@ def _gpt2_block_forward(
     return outputs  # hidden_states, present, (attentions, cross_attentions)


+# Adapted from https://github.com/huggingface/transformers/blob/v4.48.0/src/transformers/models/qwen2/modeling_qwen2.py#L499
+def _qwen2_model_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Cache] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    **kwargs,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    if (input_ids is None) ^ (inputs_embeds is not None):
+        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+    if self.gradient_checkpointing and self.training and use_cache:
+        logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.")
+        use_cache = False
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+
+    batch_size, seq_length = inputs_embeds.shape[:2]
+
+    past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+    if cache_position is None:
+        cache_position = torch.arange(
+            past_key_values_length, past_key_values_length + inputs_embeds.shape[1], device=inputs_embeds.device
+        )
+
+    if position_ids is None:
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+        position_ids = torch.arange(
+            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+        )
+        position_ids = position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)
+
+    causal_mask = self._update_causal_mask(
+        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+    )
+
+    hidden_states = inputs_embeds
+
+    # create position embeddings to be shared across the decoder layers
+    position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+    input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+
+    if past_key_values_length == 0 and past_key_values is not None:
+        # first token: remove the padding from hidden_states, as varlen attention does not accept an attention mask
+        hidden_states_copy = hidden_states
+        index = attention_mask.view(-1) != 0
+        hidden_states = (hidden_states.view(-1, hidden_states.shape[-1]))[index]
+        cos = position_embeddings[0]
+        sin = position_embeddings[1]
+        cos = (cos.reshape(-1, cos.shape[-1]))[index]
+        sin = (sin.reshape(-1, sin.shape[-1]))[index]
+        position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
+    else:
+        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+
+    if past_key_values is None:
+        attention_mask = causal_mask
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+
+    for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        layer_outputs = decoder_layer(
+            hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_values,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            input_lens=input_lens,
+            **kwargs,
+        )
+
+        hidden_states = layer_outputs[0]
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+    hidden_states = self.norm(hidden_states)
+
+    if hidden_states.shape[0] != batch_size * seq_length:
+        (hidden_states_copy.view(-1, hidden_states.shape[-1]))[attention_mask.view(-1) != 0] = hidden_states
+        hidden_states = hidden_states_copy
+    hidden_states = hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    output = BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=past_key_values if use_cache else None,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
+    return output if return_dict else output.to_tuple()
+
+
 class _IPEXAttention(nn.Module):
     def __init__(self, module, config) -> None:
         super().__init__()
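Editor's note on the hunk above (not part of the diff): the interesting part of `_qwen2_model_forward` is the prefill-time padding removal. Variable-length ("varlen") attention kernels take a packed [num_tokens, hidden] tensor plus per-sequence lengths instead of a padded [batch, seq, hidden] tensor with an attention mask, so the forward flattens the hidden states, drops padded positions via a boolean index, and scatters the results back at the end. A minimal, self-contained sketch of that pack/unpack pattern using only plain PyTorch (the `packed * 2.0` step is a placeholder for the real per-token computation):

import torch

batch_size, seq_len, hidden_size = 2, 5, 8
hidden_states = torch.randn(batch_size, seq_len, hidden_size)
# 1 = real token, 0 = padding (the first sequence is left-padded)
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# per-sequence token counts, computed the same way as `input_lens` in the diff
input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)  # tensor([3, 5], dtype=torch.int32)

# pack: keep only non-padded positions -> [num_tokens, hidden_size]
index = attention_mask.view(-1) != 0
packed = hidden_states.view(-1, hidden_size)[index]

packed = packed * 2.0  # placeholder for the real per-token computation

# unpack: scatter the results back into the padded [batch, seq, hidden] layout
restored = hidden_states.view(-1, hidden_size).clone()
restored[index] = packed
restored = restored.view(batch_size, seq_len, hidden_size)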
@@ -618,8 +737,10 @@ def __init__(self, module, config) -> None:
     def qkv_gemm(self, hidden_states):
         raise NotImplementedError("Need to implement in specific model class")

-    def rope(self, *args, **kwargs):
-        raise NotImplementedError("Need to implement in specific model class")
+    def rope(self, query, key, **kwargs):
+        position_embeddings = kwargs.pop("position_embeddings", None)
+        rotary_embedding(query, key, position_embeddings[1], position_embeddings[0], query.size(-1), True)
+        return query, key

     def postprocess_attention_output(self, attn_output):
         if self.use_sdpa:
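Editor's note (not part of the diff): the `rope` implementation that previously lived in each attention subclass is now a single base-class method that calls IPEX's fused `rotary_embedding` kernel, taking sin and cos from the shared `position_embeddings` tuple (note the sin-before-cos argument order). For reference, the standard rotate-half math that such a kernel corresponds to looks roughly like this in plain PyTorch; this is only an illustrative sketch, not the IPEX kernel:

import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_reference(query, key, cos, sin):
    # query/key: [num_tokens, num_heads, head_dim]; cos/sin: [num_tokens, 1, head_dim]
    q_rot = query * cos + rotate_half(query) * sin
    k_rot = key * cos + rotate_half(key) * sin
    return q_rot, k_rot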
@@ -748,13 +869,13 @@ class _IPEXLlamaAttention(_IPEXAttention):
     def __init__(self, module, config) -> None:
         super().__init__(module, config)
         concat_weight = torch.concat([self.q_proj.weight, self.k_proj.weight, self.v_proj.weight]).contiguous()
-        bias_list = [bias for bias in [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] if bias]
+        bias_list = [bias for bias in [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] if bias is not None]
         use_bias = bias_list != []
         self.concat_qkv = nn.Linear(concat_weight.shape[1], concat_weight.shape[0], bias=use_bias)
         self.concat_qkv.weight = nn.Parameter(concat_weight)
         if use_bias:
             concat_bias = torch.concat(bias_list, 0).contiguous()
-            self.concat_linear.bias = nn.Parameter(concat_bias)
+            self.concat_qkv.bias = nn.Parameter(concat_bias)
         self.q_slice = self.q_proj.weight.shape[0]
         self.k_slice = self.q_slice + self.k_proj.weight.shape[0]
         self.v_slice = self.k_slice + self.v_proj.weight.shape[0]
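Editor's note (not part of the diff): both one-line fixes above only start to matter once Qwen2, whose q/k/v projections carry a bias, goes through this fused-QKV path. Using `if bias` on a real bias Parameter raises "Boolean value of Tensor with more than one element is ambiguous", so the presence check must be `is not None`, and the concatenated bias has to be attached to the fused `concat_qkv` module (the old `concat_linear` attribute never existed). A standalone sketch of the corrected construction, with illustrative layer sizes:

import torch
import torch.nn as nn

q_proj = nn.Linear(16, 16, bias=True)
k_proj = nn.Linear(16, 8, bias=True)
v_proj = nn.Linear(16, 8, bias=True)

# `is not None` instead of truthiness: a multi-element Parameter cannot be used as a bool
bias_list = [b for b in [q_proj.bias, k_proj.bias, v_proj.bias] if b is not None]

concat_weight = torch.concat([q_proj.weight, k_proj.weight, v_proj.weight]).contiguous()
concat_qkv = nn.Linear(concat_weight.shape[1], concat_weight.shape[0], bias=bool(bias_list))
concat_qkv.weight = nn.Parameter(concat_weight)
if bias_list:
    concat_qkv.bias = nn.Parameter(torch.concat(bias_list, 0).contiguous())

x = torch.randn(4, 16)
assert concat_qkv(x).shape == (4, 32)  # fused q/k/v output: 16 + 8 + 8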
@@ -774,11 +895,6 @@ def qkv_gemm(self, hidden_states):

         return query, key, value

-    def rope(self, query, key, **kwargs):
-        position_embeddings = kwargs.pop("position_embeddings", None)
-        rotary_embedding(query, key, position_embeddings[1], position_embeddings[0], query.size(-1), True)
-        return query, key
-

 class _IPEXFalconAttention(_IPEXAttention):
     def __init__(self, module, config):
@@ -801,11 +917,6 @@ def qkv_gemm(self, hidden_states):
             value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim)
         return query, key, value

-    def rope(self, query, key, **kwargs):
-        position_embeddings = kwargs.pop("position_embeddings", None)
-        rotary_embedding(query, key, position_embeddings[1], position_embeddings[0], query.size(-1), True)
-        return query, key
-

 class _IPEXGPT2Attention(_IPEXAttention):
     def __init__(self, module, config) -> None:
@@ -1006,6 +1117,12 @@ def forward(self, hidden_states: torch.Tensor, **kwargs):
         return outputs


+# Qwen2 can currently just reuse the Llama decoder layer implementation.
+class _IPEXQwen2DecoderLayer(_IPEXLlamaDecoderLayer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+
 # Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/bert/modeling_bert.py#L524
 class _IPEXIntermediate(nn.Module):
     def __init__(self, module, config):
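Editor's note (not part of the diff): end to end, the patches above are what let a Qwen2 checkpoint run through the IPEX-optimized causal-LM path. A minimal usage sketch, assuming optimum-intel is installed with its IPEX extra and that `IPEXModelForCausalLM` picks up the Qwen2 patching shown here; the checkpoint id is only an example:

from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "Qwen/Qwen2-0.5B-Instruct"  # example Qwen2 checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("The capital of France is", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(generated[0], skip_special_tokens=True))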