
Commit 6bf3b8b

Committed Jan 21, 2025
fix gpt2 quant model
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent dab4a78 · commit 6bf3b8b

File tree: 2 files changed (+122, -114 lines)


optimum/exporters/ipex/model_patcher.py (+4 / -7)
@@ -14,7 +14,7 @@
 
 from transformers.models.bert.modeling_bert import BertIntermediate
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel
-from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
+from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer,
     LlamaModel,
@@ -27,13 +27,11 @@
 
 from .modeling_utils import (
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
-    _IPEXGPT2MLP,
     _falcon_model_forward,
-    _gpt2_block_forward,
     _gpt2_model_forward,
     _ipex_rms_layer_norm_forward,
     _IPEXFalconDecoderLayer,
-    _IPEXGPT2Attention,
+    _IPEXGPT2Block,
     _IPEXIntermediate,
     _IPEXLlamaDecoderLayer,
     _llama_model_forward,
@@ -106,13 +104,12 @@ def _patch_gpt2_model(model):
     """
     Patch gpt2 model:
         1. Use IPEX paged attention
+        2. Linear fusion with (Linear + Add)
     """
     num_key_value_heads = model.config.num_attention_heads
     setattr(model.config, "num_key_value_heads", num_key_value_heads)
     convert_functions(model, GPT2Model, "forward", _gpt2_model_forward)
-    convert_functions(model, GPT2Block, "forward", _gpt2_block_forward)
-    convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.device, model.config)
-    convert_class(model, GPT2MLP, _IPEXGPT2MLP, model.device, model.config)
+    convert_class(model, GPT2Block, _IPEXGPT2Block, model.device, model.config)
     return model
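Note: the patcher now replaces each whole GPT2Block with _IPEXGPT2Block in a single convert_class call, instead of patching GPT2Attention and GPT2MLP separately and rebinding GPT2Block.forward. As a rough sketch of what a convert_class-style helper does (the name convert_class_sketch and its exact signature are illustrative, not the optimum-intel implementation), the replacement walks the module tree and wraps matching submodules:

import torch.nn as nn

def convert_class_sketch(model: nn.Module, orig_cls, new_cls, device, config) -> nn.Module:
    # Recursively swap every submodule of orig_cls for new_cls(submodule, device, config).
    # The wrapper keeps a reference to the original module, so weights are reused, not copied.
    for name, child in model.named_children():
        if isinstance(child, orig_cls):
            setattr(model, name, new_cls(child, device, config))
        else:
            convert_class_sketch(child, orig_cls, new_cls, device, config)
    return model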

optimum/exporters/ipex/modeling_utils.py (+118 / -107)
@@ -558,78 +558,6 @@ def _gpt2_model_forward(
     )
 
 
-# To pass input_lens, adapted from https://github.com/huggingface/transformers/blob/v4.46.3/src/transformers/models/gpt2/modeling_gpt2.py#L602
-def _gpt2_block_forward(
-    self,
-    hidden_states: Optional[Tuple[torch.FloatTensor]],
-    layer_past: Optional[Tuple[torch.Tensor]] = None,
-    attention_mask: Optional[torch.FloatTensor] = None,
-    head_mask: Optional[torch.FloatTensor] = None,
-    encoder_hidden_states: Optional[torch.Tensor] = None,
-    encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    use_cache: Optional[bool] = False,
-    output_attentions: Optional[bool] = False,
-    **kwargs,
-) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
-    residual = hidden_states
-    hidden_states = self.ln_1(hidden_states)
-    attn_outputs = self.attn(
-        hidden_states,
-        layer_past=layer_past,
-        attention_mask=attention_mask,
-        head_mask=head_mask,
-        use_cache=use_cache,
-        output_attentions=output_attentions,
-        **kwargs,
-    )
-    attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
-    outputs = attn_outputs[1:]
-    # residual connection
-    if hasattr(self.attn, "linear_add"):
-        hidden_states = self.attn.linear_add(attn_output, residual)
-    else:
-        hidden_states = attn_output + residual
-
-    if encoder_hidden_states is not None:
-        # add one self-attention block for cross-attention
-        if not hasattr(self, "crossattention"):
-            raise ValueError(
-                f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
-                "cross-attention layers by setting `config.add_cross_attention=True`"
-            )
-        residual = hidden_states
-        hidden_states = self.ln_cross_attn(hidden_states)
-        cross_attn_outputs = self.crossattention(
-            hidden_states,
-            attention_mask=attention_mask,
-            head_mask=head_mask,
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=encoder_attention_mask,
-            output_attentions=output_attentions,
-            **kwargs,
-        )
-        attn_output = cross_attn_outputs[0]
-        # residual connection
-        hidden_states = residual + attn_output
-        outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights
-
-    residual = hidden_states
-    hidden_states = self.ln_2(hidden_states)
-    feed_forward_hidden_states = self.mlp(hidden_states)
-    # residual connection
-    if hasattr(self.mlp, "linear_add"):
-        hidden_states = self.mlp.linear_add(feed_forward_hidden_states, residual)
-    else:
-        hidden_states = residual + feed_forward_hidden_states
-
-    if use_cache:
-        outputs = (hidden_states,) + outputs
-    else:
-        outputs = (hidden_states,) + outputs[1:]
-
-    return outputs  # hidden_states, present, (attentions, cross_attentions)
-
-
 class _IPEXAttention(nn.Module):
     def __init__(self, module, device, config) -> None:
         super().__init__()
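The removed _gpt2_block_forward above was previously bound onto GPT2Block via convert_functions; its body moves into _IPEXGPT2Block.forward, added further down in this file. For contrast with the class-wrapping approach, here is a rough sketch of that method-rebinding mechanism (convert_functions_sketch is an illustrative stand-in, not the actual optimum-intel helper):

import types

import torch.nn as nn

def convert_functions_sketch(model: nn.Module, target_cls, method_name: str, new_fn) -> None:
    # Bind new_fn as an instance method on every submodule of target_cls,
    # shadowing the original method (e.g. GPT2Block.forward).
    for module in model.modules():
        if isinstance(module, target_cls):
            setattr(module, method_name, types.MethodType(new_fn, module))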
@@ -844,26 +772,27 @@ class _IPEXGPT2Attention(_IPEXAttention):
     def __init__(self, module, device, config) -> None:
         self.num_key_value_heads = config.num_key_value_heads
         super().__init__(module, device, config)
-        if getattr(config, "quantization_config", None):
-            _remove_hooks_for_ipex(self, True)
-
         _setattr_from_module(self, module)
-        self.c_attn_linear = nn.Linear(self.c_attn.weight.shape[0], self.c_attn.weight.shape[1])
-        self.c_attn_linear.weight = nn.Parameter(self.c_attn.weight.t())
-        self.c_attn_linear.bias = self.c_attn.bias
-        self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
-        self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
-        self.c_proj_linear.bias = self.c_proj.bias
-        if self.module_device.type == "cpu":
-            if self.c_proj_linear not in ["LinearAllreduce"]:
-                self.linear_add = LinearAdd(self.c_proj_linear)
-
-        elif self.module_device.type == "xpu":
-            if self.c_proj_linear not in ["LinearAllreduce"]:
-                self.linear_add = XPULinearAdd(self.c_proj_linear)
+        if getattr(config, "quantization_config", None) is None:
+            self.c_attn_linear = nn.Linear(self.c_attn.weight.shape[0], self.c_attn.weight.shape[1])
+            self.c_attn_linear.weight = nn.Parameter(self.c_attn.weight.t())
+            self.c_attn_linear.bias = self.c_attn.bias
+            self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
+            self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
+            self.c_proj_linear.bias = self.c_proj.bias
+            if self.module_device.type == "cpu":
+                if self.c_proj_linear not in ["LinearAllreduce"]:
+                    self.linear_add = LinearAdd(self.c_proj_linear)
+
+            elif self.module_device.type == "xpu":
+                if self.c_proj_linear not in ["LinearAllreduce"]:
+                    self.linear_add = XPULinearAdd(self.c_proj_linear)
 
     def qkv_gemm(self, hidden_states):
-        query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1)
+        if hasattr(self, "c_attn_linear"):
+            query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1)
+        else:
+            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=-1)
         query = query.view(-1, self.num_heads, self.head_dim)
         key = key.view(-1, self.num_heads, self.head_dim)
         value = value.view(-1, self.num_heads, self.head_dim)
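GPT-2's c_attn and c_proj are transformers Conv1D modules, whose weight layout is (in_features, out_features); the fused path therefore builds an nn.Linear from the transposed weight, and the change above only does this when no quantization_config is present, falling back to the original (possibly quantized) c_attn in qkv_gemm otherwise. A standalone check of the Conv1D-to-Linear conversion that the non-quantized branch relies on (illustrative; done under no_grad to keep the snippet self-contained):

import torch
import torch.nn as nn
from transformers.pytorch_utils import Conv1D

torch.manual_seed(0)
conv = Conv1D(nf=12, nx=4)  # GPT-2 style projection; weight shape is (nx, nf) = (in, out)

with torch.no_grad():
    linear = nn.Linear(4, 12)
    linear.weight = nn.Parameter(conv.weight.t())  # transpose to nn.Linear's (out, in) layout
    linear.bias = conv.bias

    x = torch.randn(2, 4)
    assert torch.allclose(conv(x), linear(x), atol=1e-6)  # same projection, same bias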
@@ -951,27 +880,29 @@ def forward(
 
 
 class _IPEXGPT2MLP(nn.Module):
-    def __init__(self, module, config) -> None:
+    def __init__(self, module, device, config) -> None:
         super().__init__()
         _setattr_from_module(self, module)
         self.config = config
-        self.module_device = next(module.parameters()).device
-        self.c_fc_linear = nn.Linear(self.c_fc.weight.shape[0], self.c_fc.weight.shape[1])
-        self.c_fc_linear.weight = nn.Parameter(self.c_fc.weight.t())
-        self.c_fc_linear.bias = self.c_fc.bias
-        self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
-        self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
-        self.c_proj_linear.bias = self.c_proj.bias
-        if self.module_device.type == "cpu":
-            self.linear_new_gelu = LinearNewGelu(self.c_fc_linear)
-
-        if self.module_device.type == "cpu":
-            if self.c_proj_linear not in ["LinearAllreduce"]:
-                self.linear_add = LinearAdd(self.c_proj_linear)
-
-        elif self.module_device.type == "xpu":
-            if self.c_proj_linear not in ["LinearAllreduce"]:
-                self.linear_add = XPULinearAdd(self.c_proj_linear)
+        self.module_device = device
+
+        if getattr(config, "quantization_config", None) is None:
+            self.c_fc_linear = nn.Linear(self.c_fc.weight.shape[0], self.c_fc.weight.shape[1])
+            self.c_fc_linear.weight = nn.Parameter(self.c_fc.weight.t())
+            self.c_fc_linear.bias = self.c_fc.bias
+            self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
+            self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
+            self.c_proj_linear.bias = self.c_proj.bias
+            if self.module_device.type == "cpu":
+                self.linear_new_gelu = LinearNewGelu(self.c_fc_linear)
+
+            if self.module_device.type == "cpu":
+                if self.c_proj_linear not in ["LinearAllreduce"]:
+                    self.linear_add = LinearAdd(self.c_proj_linear)
+
+            elif self.module_device.type == "xpu":
+                if self.c_proj_linear not in ["LinearAllreduce"]:
+                    self.linear_add = XPULinearAdd(self.c_proj_linear)
 
     def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
         if hasattr(self, "linear_new_gelu"):
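Both wrappers now build the fused LinearNewGelu / LinearAdd ops only when config.quantization_config is absent, and their forward methods dispatch via hasattr so quantized checkpoints keep running through the original (quantized) c_fc / c_proj modules. The guard works because a plain GPT-2 config carries no quantization_config attribute; transformers only attaches one when the checkpoint is loaded with quantization. A minimal illustration of that assumption:

from transformers import GPT2Config

config = GPT2Config()
# None for an unquantized model -> build the fused linears;
# a quantized checkpoint exposes config.quantization_config -> keep the original
# modules and let forward() fall back through the hasattr checks above.
print(getattr(config, "quantization_config", None))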
@@ -1048,6 +979,86 @@ def forward(self, hidden_states: torch.Tensor, **kwargs):
         return outputs
 
 
+class _IPEXGPT2Block(nn.Module):
+    def __init__(self, module, device, config):
+        super().__init__()
+        _setattr_from_module(self, module)
+        self.attn = _IPEXGPT2Attention(module.attn, device, config)
+        self.mlp = _IPEXGPT2MLP(module.mlp, device, config)
+        if getattr(config, "quantization_config", None):
+            _remove_hooks_for_ipex(self, True)
+
+    def forward(
+        self,
+        hidden_states: Optional[Tuple[torch.FloatTensor]],
+        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+        **kwargs,
+    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
+        residual = hidden_states
+        hidden_states = self.ln_1(hidden_states)
+        attn_outputs = self.attn(
+            hidden_states,
+            layer_past=layer_past,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            **kwargs,
+        )
+        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
+        outputs = attn_outputs[1:]
+        # residual connection
+        if hasattr(self.attn, "linear_add"):
+            hidden_states = self.attn.linear_add(attn_output, residual)
+        else:
+            hidden_states = attn_output + residual
+
+        if encoder_hidden_states is not None:
+            # add one self-attention block for cross-attention
+            if not hasattr(self, "crossattention"):
+                raise ValueError(
+                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
+                    "cross-attention layers by setting `config.add_cross_attention=True`"
+                )
+            residual = hidden_states
+            hidden_states = self.ln_cross_attn(hidden_states)
+            cross_attn_outputs = self.crossattention(
+                hidden_states,
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                output_attentions=output_attentions,
+                **kwargs,
+            )
+            attn_output = cross_attn_outputs[0]
+            # residual connection
+            hidden_states = residual + attn_output
+            outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights
+
+        residual = hidden_states
+        hidden_states = self.ln_2(hidden_states)
+        feed_forward_hidden_states = self.mlp(hidden_states)
+        # residual connection
+        if hasattr(self.mlp, "linear_add"):
+            hidden_states = self.mlp.linear_add(feed_forward_hidden_states, residual)
+        else:
+            hidden_states = residual + feed_forward_hidden_states
+
+        if use_cache:
+            outputs = (hidden_states,) + outputs
+        else:
+            outputs = (hidden_states,) + outputs[1:]
+
+        return outputs  # hidden_states, present, (attentions, cross_attentions)
+
+
 # Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/bert/modeling_bert.py#L524
 class _IPEXIntermediate(nn.Module):
     def __init__(self, module, device, config):
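With the block-level wrapper, quantized GPT-2 checkpoints keep their quantized projection modules while unquantized ones get the fused linear kernels, and _remove_hooks_for_ipex now runs once per block rather than per attention module. A usage sketch through optimum-intel's public entry point, assuming optimum-intel with IPEX is installed and using the stock gpt2 checkpoint (a quantized checkpoint would take the fallback path automatically via its quantization_config):

from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = IPEXModelForCausalLM.from_pretrained("gpt2")  # GPT2Block gets patched to _IPEXGPT2Block

inputs = tokenizer("Paged attention with GPT-2 on IPEX:", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))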
