Enable quant model support #1074
Changes from all commits
```diff
@@ -14,7 +14,7 @@
 from transformers.models.bert.modeling_bert import BertIntermediate
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel
-from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
+from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer,
     LlamaModel,
@@ -32,13 +32,11 @@
 from .modeling_utils import (
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
-    _IPEXGPT2MLP,
     _falcon_model_forward,
-    _gpt2_block_forward,
     _gpt2_model_forward,
     _ipex_rms_layer_norm_forward,
     _IPEXFalconDecoderLayer,
-    _IPEXGPT2Attention,
+    _IPEXGPT2Block,
     _IPEXIntermediate,
     _IPEXLlamaDecoderLayer,
     _IPEXQwen2DecoderLayer,
```
```diff
@@ -66,12 +64,12 @@ def convert_functions(m, target_m, new_function_name, new_function):
         convert_functions(sub_m, target_m, new_function_name, new_function)
 
 
-def convert_class(m, target_m, new_class, config=None):
+def convert_class(m, target_m, new_class, device, config):
     for name, sub_m in m.named_children():
         if isinstance(sub_m, target_m):
-            new_m = new_class(sub_m, config)
+            new_m = new_class(sub_m, device, config)
             setattr(m, name, new_m)
-        convert_class(sub_m, target_m, new_class, config)
+        convert_class(sub_m, target_m, new_class, device, config)
 
 
 def patch_op(m, target_m, new_op_name, new_op):
```
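For reference, the new signature means every wrapper class is now constructed from the original submodule, the model's device, and the model config. Below is a minimal, self-contained sketch of the recursion in action; `ToyLayer` and `WrappedToyLayer` are made-up stand-ins for the real `_IPEX*` classes, not code from this PR:

```python
import torch
from torch import nn


def convert_class(m, target_m, new_class, device, config):
    # Same logic as the helper in the diff above: recursively replace every
    # instance of `target_m` with `new_class(sub_m, device, config)`.
    for name, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            new_m = new_class(sub_m, device, config)
            setattr(m, name, new_m)
        convert_class(sub_m, target_m, new_class, device, config)


class ToyLayer(nn.Module):
    def __init__(self, hidden=8):
        super().__init__()
        self.linear = nn.Linear(hidden, hidden)

    def forward(self, x):
        return self.linear(x)


class WrappedToyLayer(nn.Module):
    # Stand-in for an _IPEX* wrapper: it now receives the device explicitly
    # in addition to the original module and the model config.
    def __init__(self, module, device, config):
        super().__init__()
        self.linear = module.linear.to(device)

    def forward(self, x):
        return self.linear(x)


model = nn.Sequential(ToyLayer(), ToyLayer())
convert_class(model, ToyLayer, WrappedToyLayer, torch.device("cpu"), config=None)
print(model)  # both ToyLayer children have been replaced by WrappedToyLayer
```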
```diff
@@ -89,7 +87,7 @@ def _patch_llama_model(model):
     """
     convert_functions(model, LlamaModel, "forward", _llama_model_forward)
     convert_functions(model, LlamaRMSNorm, "forward", _ipex_rms_layer_norm_forward)
-    convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.config)
+    convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.device, model.config)
     return model
```
```diff
@@ -105,21 +103,20 @@ def _patch_falcon_model(model):
     setattr(model.config, "num_key_value_heads", num_key_value_heads)
     convert_functions(model, FalconModel, "forward", _falcon_model_forward)
     replace_customized_linear_with_linear(model)
-    convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.config)
+    convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.device, model.config)
     return model
 
 
 def _patch_gpt2_model(model):
     """
     Patch gpt2 model:
     1. Use IPEX paged attention
     2. Linear fusion with (Linear + Add)
     """
     num_key_value_heads = model.config.num_attention_heads
     setattr(model.config, "num_key_value_heads", num_key_value_heads)
     convert_functions(model, GPT2Model, "forward", _gpt2_model_forward)
-    convert_functions(model, GPT2Block, "forward", _gpt2_block_forward)
-    convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config)
-    convert_class(model, GPT2MLP, _IPEXGPT2MLP, model.config)
+    convert_class(model, GPT2Block, _IPEXGPT2Block, model.device, model.config)
     return model
```

Review comment (on lines -120 to +119): Why are the MLP and attention no longer patched here?

Reply: Because they are now handled inside `_IPEXGPT2Block`.
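To make that reply concrete: a block-level wrapper can take over (or rebuild) its attention and MLP submodules when it is constructed, so separate `convert_class` passes for `GPT2Attention` and `GPT2MLP` become redundant. The classes below are purely illustrative toys, not the actual `_IPEXGPT2Block` implementation:

```python
from torch import nn


class ToyAttention(nn.Module):
    def __init__(self, hidden=8):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, x):
        return self.proj(x)


class ToyMLP(nn.Module):
    def __init__(self, hidden=8):
        super().__init__()
        self.fc = nn.Linear(hidden, hidden)

    def forward(self, x):
        return self.fc(x)


class ToyBlock(nn.Module):
    def __init__(self, hidden=8):
        super().__init__()
        self.attn = ToyAttention(hidden)
        self.mlp = ToyMLP(hidden)

    def forward(self, x):
        return x + self.mlp(x + self.attn(x))


class ToyBlockWrapper(nn.Module):
    """Illustrative stand-in for a block-level wrapper like _IPEXGPT2Block."""

    def __init__(self, block, device, config=None):
        super().__init__()
        # The block wrapper takes ownership of its attention and MLP submodules
        # (a real implementation would build fused IPEX variants here), which is
        # why the attention and MLP no longer need their own convert_class calls.
        self.attn = block.attn.to(device)
        self.mlp = block.mlp.to(device)

    def forward(self, x):
        return x + self.mlp(x + self.attn(x))
```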
```diff
@@ -131,7 +128,7 @@ def _patch_qwen2_model(model):
     """
     convert_functions(model, Qwen2Model, "forward", _qwen2_model_forward)
     convert_functions(model, Qwen2RMSNorm, "forward", _ipex_rms_layer_norm_forward)
-    convert_class(model, Qwen2DecoderLayer, _IPEXQwen2DecoderLayer, model.config)
+    convert_class(model, Qwen2DecoderLayer, _IPEXQwen2DecoderLayer, model.device, model.config)
     return model
@@ -140,7 +137,7 @@ def _patch_bert_model(model):
     Patch bert model:
     1. Linear fusion with Linear + Gelu
     """
-    convert_class(model, BertIntermediate, _IPEXIntermediate)
+    convert_class(model, BertIntermediate, _IPEXIntermediate, model.device, model.config)
     return model
@@ -149,7 +146,7 @@ def _patch_vit_model(model):
     Patch vit model:
     1. Linear fusion with Linear + Gelu
     """
-    convert_class(model, ViTIntermediate, _IPEXIntermediate)
+    convert_class(model, ViTIntermediate, _IPEXIntermediate, model.device, model.config)
     return model
```
Large diffs are not rendered by default, so the rest of this PR's changes are not shown here.
Review comment: No luck with autoawq?

Reply: Yes, the autoawq installation has some issues; I will figure out why this environment cannot install autoawq. The tests passed in my local environment, which has autoawq installed.
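For context, the kind of quantized-model flow such tests exercise might look roughly like the following; the checkpoint id is a placeholder and the exact test setup is an assumption, not part of this PR:

```python
# Sketch only: assumes autoawq is installed and that the placeholder
# checkpoint id below points to an AWQ-quantized causal LM.
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "example-org/example-llama-awq"  # hypothetical quantized checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello, IPEX!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```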