Commit 3fdb3a5

fix non patch path

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

1 parent 72ac9e6 · commit 3fdb3a5

2 files changed: +55 −4 lines

optimum/exporters/ipex/model_patcher.py (+3 −1)
@@ -14,7 +14,7 @@

 from transformers.models.bert.modeling_bert import BertIntermediate
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel
-from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
+from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer,
     LlamaModel,
@@ -27,6 +27,7 @@

 from .modeling_utils import (
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
+    _IPEXGPT2MLP,
     _falcon_model_forward,
     _gpt2_block_forward,
     _gpt2_model_forward,
@@ -111,6 +112,7 @@ def _patch_gpt2_model(model):
     convert_functions(model, GPT2Model, "forward", _gpt2_model_forward)
     convert_functions(model, GPT2Block, "forward", _gpt2_block_forward)
     convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config)
+    convert_class(model, GPT2MLP, _IPEXGPT2MLP, model.config)
     return model
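For context, the new line above registers _IPEXGPT2MLP as the replacement class for every GPT2MLP submodule via convert_class. The sketch below only illustrates that kind of module-swap pass under an assumed, hypothetical helper name (swap_modules); the real convert_class in model_patcher.py may differ in details.

# Hypothetical sketch of a class-swap pass like the one convert_class performs.
# swap_modules is an illustrative name, not part of optimum-intel.
import torch.nn as nn

def swap_modules(model: nn.Module, orig_cls: type, new_cls: type, config) -> nn.Module:
    # Recursively walk the module tree and wrap every matching child,
    # mirroring convert_class(model, GPT2MLP, _IPEXGPT2MLP, model.config).
    for name, child in model.named_children():
        if isinstance(child, orig_cls):
            setattr(model, name, new_cls(child, config))
        else:
            swap_modules(child, orig_cls, new_cls, config)
    return model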
optimum/exporters/ipex/modeling_utils.py (+52 −3)
@@ -46,6 +46,7 @@
         LinearAdd,
         LinearAddAdd,
         LinearGelu,
+        LinearNewGelu,
         PagedAttention,
     )

@@ -557,7 +558,10 @@ def _gpt2_block_forward(
     attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
     outputs = attn_outputs[1:]
     # residual connection
-    hidden_states = attn_output + residual
+    if hasattr(self.attn, "linear_add"):
+        hidden_states = self.attn.linear_add(attn_output, residual)
+    else:
+        hidden_states = attn_output + residual

     if encoder_hidden_states is not None:
         # add one self-attention block for cross-attention
@@ -586,7 +590,10 @@ def _gpt2_block_forward(
     hidden_states = self.ln_2(hidden_states)
     feed_forward_hidden_states = self.mlp(hidden_states)
     # residual connection
-    hidden_states = residual + feed_forward_hidden_states
+    if hasattr(self.mlp, "linear_add"):
+        hidden_states = self.mlp.linear_add(feed_forward_hidden_states, residual)
+    else:
+        hidden_states = residual + feed_forward_hidden_states

     if use_cache:
         outputs = (hidden_states,) + outputs
@@ -780,6 +787,13 @@ def __init__(self, module, config) -> None:
         self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
         self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
         self.c_proj_linear.bias = self.c_proj.bias
+        if self.module_device.type == "cpu":
+            if self.c_proj_linear not in ["LinearAllreduce"]:
+                self.linear_add = LinearAdd(self.c_proj_linear)
+
+        elif self.module_device.type == "xpu":
+            if self.c_proj_linear not in ["LinearAllreduce"]:
+                self.linear_add = XPULinearAdd(self.c_proj_linear)

     def qkv_gemm(self, hidden_states):
         query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1)
@@ -795,7 +809,8 @@ def postprocess_attention_output(self, attn_output):
         if self.use_sdpa:
             attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1])
-        attn_output = self.c_proj(attn_output)
+        if not hasattr(self, "linear_add"):
+            attn_output = self.c_proj(attn_output)
         return attn_output


@@ -866,6 +881,40 @@ def forward(
         return output


+class _IPEXGPT2MLP(nn.Module):
+    def __init__(self, module, config) -> None:
+        super().__init__()
+        _setattr_from_module(self, module)
+        self.config = config
+        self.module_device = next(module.parameters()).device
+        self.c_fc_linear = nn.Linear(self.c_fc.weight.shape[0], self.c_fc.weight.shape[1])
+        self.c_fc_linear.weight = nn.Parameter(self.c_fc.weight.t())
+        self.c_fc_linear.bias = self.c_fc.bias
+        self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
+        self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
+        self.c_proj_linear.bias = self.c_proj.bias
+        if self.module_device.type == "cpu":
+            self.linear_new_gelu = LinearNewGelu(self.c_fc_linear)
+
+        if self.module_device.type == "cpu":
+            if self.c_proj_linear not in ["LinearAllreduce"]:
+                self.linear_add = LinearAdd(self.c_proj_linear)
+
+        elif self.module_device.type == "xpu":
+            if self.c_proj_linear not in ["LinearAllreduce"]:
+                self.linear_add = XPULinearAdd(self.c_proj_linear)
+
+    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+        if hasattr(self, "linear_new_gelu"):
+            hidden_states = self.linear_new_gelu(hidden_states)
+        else:
+            hidden_states = self.c_fc(hidden_states)
+            hidden_states = self.act(hidden_states)
+        if not hasattr(self, "linear_add"):
+            hidden_states = self.c_proj(hidden_states)
+        return hidden_states


 # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
 class _IPEXLlamaDecoderLayer(nn.Module):
     def __init__(self, module, config):
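The "non patch path" named in the commit message is the fallback taken when the fused IPEX modules (LinearNewGelu, LinearAdd) were never attached, for example on devices where they are not created: the forward must then fall back to the stock c_fc -> act -> c_proj computation instead of skipping it. Below is a minimal, self-contained sketch of that hasattr-based dispatch, using a toy MLP and a stand-in fused op rather than the real IPEX kernels.

# Toy illustration of the fused-if-available / eager-otherwise dispatch used in
# _IPEXGPT2MLP.forward. FusedLinearNewGelu is a stand-in, not IPEX's LinearNewGelu.
import torch
import torch.nn as nn


class FusedLinearNewGelu(nn.Module):
    # Stand-in "fused" op: linear followed by tanh-approximated GELU,
    # which matches GPT-2's NewGELU activation.
    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        return nn.functional.gelu(self.linear(x), approximate="tanh")


class ToyGPT2MLP(nn.Module):
    def __init__(self, hidden: int, inner: int, fuse: bool):
        super().__init__()
        self.c_fc = nn.Linear(hidden, inner)
        self.act = nn.GELU(approximate="tanh")
        self.c_proj = nn.Linear(inner, hidden)
        if fuse:
            # Patched path: only created when the fused op is available.
            self.linear_new_gelu = FusedLinearNewGelu(self.c_fc)

    def forward(self, x):
        if hasattr(self, "linear_new_gelu"):
            x = self.linear_new_gelu(x)
        else:
            # Non-patched path: still runs c_fc + act, which is what the commit restores.
            x = self.c_fc(x)
            x = self.act(x)
        return self.c_proj(x)


x = torch.randn(2, 8)
fused, eager = ToyGPT2MLP(8, 32, fuse=True), ToyGPT2MLP(8, 32, fuse=False)
eager.load_state_dict(fused.state_dict(), strict=False)
print(torch.allclose(fused(x), eager(x)))  # both paths give the same result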
