@@ -558,78 +558,6 @@ def _gpt2_model_forward(
     )


-# To pass input_lens, adapted from https://github.com/huggingface/transformers/blob/v4.46.3/src/transformers/models/gpt2/modeling_gpt2.py#L602
-def _gpt2_block_forward(
-    self,
-    hidden_states: Optional[Tuple[torch.FloatTensor]],
-    layer_past: Optional[Tuple[torch.Tensor]] = None,
-    attention_mask: Optional[torch.FloatTensor] = None,
-    head_mask: Optional[torch.FloatTensor] = None,
-    encoder_hidden_states: Optional[torch.Tensor] = None,
-    encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    use_cache: Optional[bool] = False,
-    output_attentions: Optional[bool] = False,
-    **kwargs,
-) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
-    residual = hidden_states
-    hidden_states = self.ln_1(hidden_states)
-    attn_outputs = self.attn(
-        hidden_states,
-        layer_past=layer_past,
-        attention_mask=attention_mask,
-        head_mask=head_mask,
-        use_cache=use_cache,
-        output_attentions=output_attentions,
-        **kwargs,
-    )
-    attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
-    outputs = attn_outputs[1:]
-    # residual connection
-    if hasattr(self.attn, "linear_add"):
-        hidden_states = self.attn.linear_add(attn_output, residual)
-    else:
-        hidden_states = attn_output + residual
-
-    if encoder_hidden_states is not None:
-        # add one self-attention block for cross-attention
-        if not hasattr(self, "crossattention"):
-            raise ValueError(
-                f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
-                "cross-attention layers by setting `config.add_cross_attention=True`"
-            )
-        residual = hidden_states
-        hidden_states = self.ln_cross_attn(hidden_states)
-        cross_attn_outputs = self.crossattention(
-            hidden_states,
-            attention_mask=attention_mask,
-            head_mask=head_mask,
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=encoder_attention_mask,
-            output_attentions=output_attentions,
-            **kwargs,
-        )
-        attn_output = cross_attn_outputs[0]
-        # residual connection
-        hidden_states = residual + attn_output
-        outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights
-
-    residual = hidden_states
-    hidden_states = self.ln_2(hidden_states)
-    feed_forward_hidden_states = self.mlp(hidden_states)
-    # residual connection
-    if hasattr(self.mlp, "linear_add"):
-        hidden_states = self.mlp.linear_add(feed_forward_hidden_states, residual)
-    else:
-        hidden_states = residual + feed_forward_hidden_states
-
-    if use_cache:
-        outputs = (hidden_states,) + outputs
-    else:
-        outputs = (hidden_states,) + outputs[1:]
-
-    return outputs  # hidden_states, present, (attentions, cross_attentions)
-
-
 class _IPEXAttention(nn.Module):
     def __init__(self, module, device, config) -> None:
         super().__init__()
@@ -844,26 +772,27 @@ class _IPEXGPT2Attention(_IPEXAttention):
     def __init__(self, module, device, config) -> None:
         self.num_key_value_heads = config.num_key_value_heads
         super().__init__(module, device, config)
-        if getattr(config, "quantization_config", None):
-            _remove_hooks_for_ipex(self, True)
-
         _setattr_from_module(self, module)
-        self.c_attn_linear = nn.Linear(self.c_attn.weight.shape[0], self.c_attn.weight.shape[1])
-        self.c_attn_linear.weight = nn.Parameter(self.c_attn.weight.t())
-        self.c_attn_linear.bias = self.c_attn.bias
-        self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
-        self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
-        self.c_proj_linear.bias = self.c_proj.bias
-        if self.module_device.type == "cpu":
-            if self.c_proj_linear not in ["LinearAllreduce"]:
-                self.linear_add = LinearAdd(self.c_proj_linear)
-
-        elif self.module_device.type == "xpu":
-            if self.c_proj_linear not in ["LinearAllreduce"]:
-                self.linear_add = XPULinearAdd(self.c_proj_linear)
+        if getattr(config, "quantization_config", None) is None:
+            self.c_attn_linear = nn.Linear(self.c_attn.weight.shape[0], self.c_attn.weight.shape[1])
+            self.c_attn_linear.weight = nn.Parameter(self.c_attn.weight.t())
+            self.c_attn_linear.bias = self.c_attn.bias
+            self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
+            self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
+            self.c_proj_linear.bias = self.c_proj.bias
+            if self.module_device.type == "cpu":
+                if self.c_proj_linear not in ["LinearAllreduce"]:
+                    self.linear_add = LinearAdd(self.c_proj_linear)
+
+            elif self.module_device.type == "xpu":
+                if self.c_proj_linear not in ["LinearAllreduce"]:
+                    self.linear_add = XPULinearAdd(self.c_proj_linear)

     def qkv_gemm(self, hidden_states):
-        query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1)
+        if hasattr(self, "c_attn_linear"):
+            query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1)
+        else:
+            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=-1)
         query = query.view(-1, self.num_heads, self.head_dim)
         key = key.view(-1, self.num_heads, self.head_dim)
         value = value.view(-1, self.num_heads, self.head_dim)
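For reference, GPT-2's `c_attn` and `c_proj` are `transformers` `Conv1D` modules, which store their weight as `(in_features, out_features)`, i.e. transposed relative to `nn.Linear`. The added lines above rebuild them as plain `nn.Linear` layers (only when no `quantization_config` is set) so the fused ops such as `LinearAdd` can be constructed from them, which is why the weight is transposed with `.t()`. A minimal, self-contained sketch of that conversion; the helper name `conv1d_to_linear` is illustrative and not part of this file:

```python
import torch
from torch import nn
from transformers.pytorch_utils import Conv1D


def conv1d_to_linear(conv: Conv1D) -> nn.Linear:
    # Conv1D.weight has shape (in_features, out_features); nn.Linear stores
    # (out_features, in_features), hence the transpose, mirroring the
    # `nn.Parameter(self.c_attn.weight.t())` lines in the diff above.
    in_features, out_features = conv.weight.shape
    linear = nn.Linear(in_features, out_features)
    linear.weight = nn.Parameter(conv.weight.t())
    linear.bias = conv.bias
    return linear


conv = Conv1D(nf=12, nx=4)  # nf = out_features, nx = in_features
linear = conv1d_to_linear(conv)
x = torch.randn(2, 4)
print(torch.allclose(conv(x), linear(x), atol=1e-5))  # expected: True
```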
@@ -951,27 +880,29 @@ def forward(


 class _IPEXGPT2MLP(nn.Module):
-    def __init__(self, module, config) -> None:
+    def __init__(self, module, device, config) -> None:
         super().__init__()
         _setattr_from_module(self, module)
         self.config = config
-        self.module_device = next(module.parameters()).device
-        self.c_fc_linear = nn.Linear(self.c_fc.weight.shape[0], self.c_fc.weight.shape[1])
-        self.c_fc_linear.weight = nn.Parameter(self.c_fc.weight.t())
-        self.c_fc_linear.bias = self.c_fc.bias
-        self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
-        self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
-        self.c_proj_linear.bias = self.c_proj.bias
-        if self.module_device.type == "cpu":
-            self.linear_new_gelu = LinearNewGelu(self.c_fc_linear)
-
-        if self.module_device.type == "cpu":
-            if self.c_proj_linear not in ["LinearAllreduce"]:
-                self.linear_add = LinearAdd(self.c_proj_linear)
-
-        elif self.module_device.type == "xpu":
-            if self.c_proj_linear not in ["LinearAllreduce"]:
-                self.linear_add = XPULinearAdd(self.c_proj_linear)
+        self.module_device = device
+
+        if getattr(config, "quantization_config", None) is None:
+            self.c_fc_linear = nn.Linear(self.c_fc.weight.shape[0], self.c_fc.weight.shape[1])
+            self.c_fc_linear.weight = nn.Parameter(self.c_fc.weight.t())
+            self.c_fc_linear.bias = self.c_fc.bias
+            self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
+            self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
+            self.c_proj_linear.bias = self.c_proj.bias
+            if self.module_device.type == "cpu":
+                self.linear_new_gelu = LinearNewGelu(self.c_fc_linear)
+
+            if self.module_device.type == "cpu":
+                if self.c_proj_linear not in ["LinearAllreduce"]:
+                    self.linear_add = LinearAdd(self.c_proj_linear)
+
+            elif self.module_device.type == "xpu":
+                if self.c_proj_linear not in ["LinearAllreduce"]:
+                    self.linear_add = XPULinearAdd(self.c_proj_linear)

     def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
         if hasattr(self, "linear_new_gelu"):
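`_IPEXGPT2MLP` mirrors the attention change: the `Conv1D` layers are converted to `nn.Linear` only in the non-quantized case, and on CPU the `c_fc` projection is wrapped in `LinearNewGelu`, which presumably fuses the projection with GPT-2's "new" (tanh-approximated) GELU. For reference, that activation is the standard `NewGELUActivation` formula from `transformers`; a small sketch checking it against PyTorch's tanh-approximate GELU:

```python
import math

import torch


def new_gelu(x: torch.Tensor) -> torch.Tensor:
    # GPT-2's "new" GELU (tanh approximation), as in transformers' NewGELUActivation.
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))


x = torch.randn(8)
print(torch.allclose(new_gelu(x), torch.nn.functional.gelu(x, approximate="tanh")))  # expected: True
```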
@@ -1048,6 +979,86 @@ def forward(self, hidden_states: torch.Tensor, **kwargs):
         return outputs


+class _IPEXGPT2Block(nn.Module):
+    def __init__(self, module, device, config):
+        super().__init__()
+        _setattr_from_module(self, module)
+        self.attn = _IPEXGPT2Attention(module.attn, device, config)
+        self.mlp = _IPEXGPT2MLP(module.mlp, device, config)
+        if getattr(config, "quantization_config", None):
+            _remove_hooks_for_ipex(self, True)
+
+    def forward(
+        self,
+        hidden_states: Optional[Tuple[torch.FloatTensor]],
+        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+        **kwargs,
+    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
+        residual = hidden_states
+        hidden_states = self.ln_1(hidden_states)
+        attn_outputs = self.attn(
+            hidden_states,
+            layer_past=layer_past,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            **kwargs,
+        )
+        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
+        outputs = attn_outputs[1:]
+        # residual connection
+        if hasattr(self.attn, "linear_add"):
+            hidden_states = self.attn.linear_add(attn_output, residual)
+        else:
+            hidden_states = attn_output + residual
+
+        if encoder_hidden_states is not None:
+            # add one self-attention block for cross-attention
+            if not hasattr(self, "crossattention"):
+                raise ValueError(
+                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
+                    "cross-attention layers by setting `config.add_cross_attention=True`"
+                )
+            residual = hidden_states
+            hidden_states = self.ln_cross_attn(hidden_states)
+            cross_attn_outputs = self.crossattention(
+                hidden_states,
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                output_attentions=output_attentions,
+                **kwargs,
+            )
+            attn_output = cross_attn_outputs[0]
+            # residual connection
+            hidden_states = residual + attn_output
+            outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights
+
+        residual = hidden_states
+        hidden_states = self.ln_2(hidden_states)
+        feed_forward_hidden_states = self.mlp(hidden_states)
+        # residual connection
+        if hasattr(self.mlp, "linear_add"):
+            hidden_states = self.mlp.linear_add(feed_forward_hidden_states, residual)
+        else:
+            hidden_states = residual + feed_forward_hidden_states
+
+        if use_cache:
+            outputs = (hidden_states,) + outputs
+        else:
+            outputs = (hidden_states,) + outputs[1:]
+
+        return outputs  # hidden_states, present, (attentions, cross_attentions)
+
+
 # Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/bert/modeling_bert.py#L524
 class _IPEXIntermediate(nn.Module):
     def __init__(self, module, device, config):
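Throughout the new `_IPEXGPT2Block.forward` (and the attention/MLP forwards above), the pattern is the same: a fused op is attached as an attribute only when it could be built, and the code falls back to the eager computation otherwise via `hasattr`. A self-contained sketch of that dispatch, where `FusedResidualAdd` and `Projection` are only toy stand-ins for `LinearAdd`/`XPULinearAdd` and the modules in this file:

```python
import torch
from torch import nn


class FusedResidualAdd(nn.Module):
    """Toy stand-in for LinearAdd/XPULinearAdd: projection plus residual add in one call."""

    def __init__(self, proj: nn.Linear):
        super().__init__()
        self.proj = proj

    def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        return self.proj(x) + residual


class Projection(nn.Module):
    def __init__(self, hidden_size: int, fuse: bool):
        super().__init__()
        self.c_proj = nn.Linear(hidden_size, hidden_size)
        if fuse:  # mirrors "only when the fused op could be created for this device/config"
            self.linear_add = FusedResidualAdd(self.c_proj)

    def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        # Same hasattr-based dispatch as the hunks above.
        if hasattr(self, "linear_add"):
            return self.linear_add(x, residual)
        return self.c_proj(x) + residual


x, residual = torch.randn(2, 8), torch.randn(2, 8)
fused, eager = Projection(8, fuse=True), Projection(8, fuse=False)
eager.c_proj.load_state_dict(fused.c_proj.state_dict())
print(torch.allclose(fused(x, residual), eager(x, residual)))  # expected: True
```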