     LinearAdd,
     LinearAddAdd,
     LinearGelu,
+    LinearNewGelu,
     PagedAttention,
 )
@@ -557,7 +558,10 @@ def _gpt2_block_forward(
     attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
     outputs = attn_outputs[1:]
     # residual connection
-    hidden_states = attn_output + residual
+    if hasattr(self.attn, "linear_add"):
+        hidden_states = self.attn.linear_add(attn_output, residual)
+    else:
+        hidden_states = attn_output + residual
 
     if encoder_hidden_states is not None:
         # add one self-attention block for cross-attention
@@ -586,7 +590,10 @@ def _gpt2_block_forward(
     hidden_states = self.ln_2(hidden_states)
     feed_forward_hidden_states = self.mlp(hidden_states)
     # residual connection
-    hidden_states = residual + feed_forward_hidden_states
+    if hasattr(self.mlp, "linear_add"):
+        hidden_states = self.mlp.linear_add(feed_forward_hidden_states, residual)
+    else:
+        hidden_states = residual + feed_forward_hidden_states
 
     if use_cache:
         outputs = (hidden_states,) + outputs
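Note on the dispatch above: the `linear_add` attribute it probes is expected to behave like the output projection fused with the residual add, i.e. `linear(x) + residual` in a single op; the same contract is why `postprocess_attention_output` further down can skip `c_proj` when `linear_add` is present. A minimal pure-PyTorch sketch of that contract (the `ReferenceLinearAdd` class and the shapes are illustrative stand-ins, not the IPEX kernel):

```python
import torch
from torch import nn


class ReferenceLinearAdd(nn.Module):
    """Stand-in for a fused linear + residual-add op: y = linear(x) + residual."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear

    def forward(self, x, residual):
        return self.linear(x) + residual


# The block forward only checks for the attribute, so attaching the fused op
# to the attention/MLP wrapper is enough to switch code paths.
proj = nn.Linear(64, 64)
attn_output = torch.randn(2, 8, 64)
residual = torch.randn(2, 8, 64)

fused = ReferenceLinearAdd(proj)(attn_output, residual)
eager = proj(attn_output) + residual
print(torch.allclose(fused, eager))  # True
```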
@@ -780,6 +787,13 @@ def __init__(self, module, config) -> None:
         self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
         self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
         self.c_proj_linear.bias = self.c_proj.bias
+        if self.module_device.type == "cpu":
+            if self.c_proj.__class__.__name__ not in ["LinearAllreduce"]:
+                self.linear_add = LinearAdd(self.c_proj_linear)
+
+        elif self.module_device.type == "xpu":
+            if self.c_proj.__class__.__name__ not in ["LinearAllreduce"]:
+                self.linear_add = XPULinearAdd(self.c_proj_linear)
 
     def qkv_gemm(self, hidden_states):
         query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1)
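The transpose in `nn.Parameter(self.c_proj.weight.t())` is needed because GPT-2 keeps its projections as `Conv1D` modules, which store weights as `(in_features, out_features)`, while `nn.Linear` stores `(out_features, in_features)`. A quick equivalence check (shapes chosen arbitrarily for illustration):

```python
import torch
from torch import nn
from transformers.pytorch_utils import Conv1D

torch.manual_seed(0)

# Conv1D(nf, nx) maps nx -> nf features and stores weight with shape (nx, nf),
# i.e. (in, out); nn.Linear expects (out, in), hence the .t() above.
conv = Conv1D(nf=8, nx=4)
linear = nn.Linear(4, 8)
linear.weight = nn.Parameter(conv.weight.t())  # (8, 4)
linear.bias = conv.bias

x = torch.randn(2, 3, 4)
print(torch.allclose(conv(x), linear(x), atol=1e-6))  # True
```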
@@ -795,7 +809,8 @@ def postprocess_attention_output(self, attn_output):
         if self.use_sdpa:
             attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1])
-        attn_output = self.c_proj(attn_output)
+        if not hasattr(self, "linear_add"):
+            attn_output = self.c_proj(attn_output)
         return attn_output
 
 
@@ -866,6 +881,40 @@ def forward(
         return output
 
 
+class _IPEXGPT2MLP(nn.Module):
+    def __init__(self, module, config) -> None:
+        super().__init__()
+        _setattr_from_module(self, module)
+        self.config = config
+        self.module_device = next(module.parameters()).device
+        self.c_fc_linear = nn.Linear(self.c_fc.weight.shape[0], self.c_fc.weight.shape[1])
+        self.c_fc_linear.weight = nn.Parameter(self.c_fc.weight.t())
+        self.c_fc_linear.bias = self.c_fc.bias
+        self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
+        self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
+        self.c_proj_linear.bias = self.c_proj.bias
+        if self.module_device.type == "cpu":
+            self.linear_new_gelu = LinearNewGelu(self.c_fc_linear)
+
+        if self.module_device.type == "cpu":
+            if self.c_proj.__class__.__name__ not in ["LinearAllreduce"]:
+                self.linear_add = LinearAdd(self.c_proj_linear)
+
+        elif self.module_device.type == "xpu":
+            if self.c_proj.__class__.__name__ not in ["LinearAllreduce"]:
+                self.linear_add = XPULinearAdd(self.c_proj_linear)
+
+    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+        if hasattr(self, "linear_new_gelu"):
+            hidden_states = self.linear_new_gelu(hidden_states)
+        else:
+            hidden_states = self.c_fc(hidden_states)
+            hidden_states = self.act(hidden_states)
+        if not hasattr(self, "linear_add"):
+            hidden_states = self.c_proj(hidden_states)
+        return hidden_states
+
+
 # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
 class _IPEXLlamaDecoderLayer(nn.Module):
     def __init__(self, module, config):
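Putting the MLP pieces together: on CPU the fused path is the linear-plus-NewGELU op on `c_fc` followed by the block-level `linear_add`, which applies `c_proj` together with the residual add, and it should reproduce the stock GPT-2 MLP path of `c_fc → NewGELU → c_proj → + residual`. A rough equivalence sketch with stand-in fused modules (the `FusedLinearNewGelu`/`FusedLinearAdd` classes below are illustrative, not the IPEX kernels):

```python
import torch
from torch import nn
from transformers.activations import NewGELUActivation


class FusedLinearNewGelu(nn.Module):
    """Stand-in for a linear fused with the NewGELU activation."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear
        self.act = NewGELUActivation()

    def forward(self, x):
        return self.act(self.linear(x))


class FusedLinearAdd(nn.Module):
    """Stand-in for a linear fused with a residual add."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear

    def forward(self, x, residual):
        return self.linear(x) + residual


hidden, inner = 64, 256
c_fc, c_proj = nn.Linear(hidden, inner), nn.Linear(inner, hidden)
act = NewGELUActivation()

x = torch.randn(2, 8, hidden)
residual = torch.randn(2, 8, hidden)

# Stock GPT-2 MLP path followed by the residual connection in the block.
eager = residual + c_proj(act(c_fc(x)))

# Fused path: the MLP applies linear + NewGELU, the block applies linear_add.
fused = FusedLinearAdd(c_proj)(FusedLinearNewGelu(c_fc)(x), residual)

print(torch.allclose(eager, fused, atol=1e-6))  # True
```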