
Commit bc1d034

Enable quant model support (#1074)
* enable IPEXModelForSeq2SeqLM
* set static cache
* add tests for IPEXModelForSeq2SeqLM
* add docs
* fix readme
* refactor compile
* fix check
* fix ruff check
* enable quantized model
* add bnb test
* add bnb tests in yaml
* fix tests
* disable bnb tests
* fix gpt2
* set actual device
* assign device when convert class
* fix class init
* fix ipex attn init
* rm set device on config
* fix format
* fix mlp class init
* add use_cache param when init generation config
* fix gpt2 quant model
* fix falcon linear fusion
* fix falcon
* enable awq model test
* fix install (repeated in 18 consecutive CI-setup commits)
* enable bnb test
* remove useless device
* update python to 3.10 on test_ipex
* Apply suggestions from code review
* install autoawq
* install wheel
* fix install autoawq
* rm autoawq
* fix concat qkv
* fix format
* fix qwen patch
* fix bias
* rm autoawq test
* fix style

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
1 parent 8361e45 commit bc1d034
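
This commit wires quantized checkpoints (bitsandbytes, AWQ) through the IPEX model classes. A minimal usage sketch of what the change enables, assuming the optimum-intel API: the model id is a placeholder, and forwarding quantization_config through from_pretrained is inferred from the commit message rather than shown in this diff.

import torch
from transformers import AutoTokenizer, BitsAndBytesConfig
from optimum.intel import IPEXModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint, any supported causal LM
# 4-bit bitsandbytes quantization; assumed to be forwarded to transformers' loader.
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)

inputs = tokenizer("Quantized inference with IPEX:", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output[0], skip_special_tokens=True))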

File tree

5 files changed: +296 -187 lines changed


.github/workflows/test_ipex.yml

+7 -1

@@ -30,14 +30,20 @@ jobs:
       - name: Setup Python
         uses: actions/setup-python@v5
         with:
-          python-version: 3.9
+          python-version: "3.10"
 
       - name: Install dependencies
         run: |
           pip install --upgrade pip
           pip install torch==${{ matrix.torch-version }} torchaudio torchvision --extra-index-url https://download.pytorch.org/whl/cpu
           pip install .[ipex,tests] transformers[testing]==${{ matrix.transformers-version }} intel_extension_for_pytorch==${{ matrix.torch-version }}
 
+      - name: Install bitsandbytes
+        run: |
+          git clone --branch multi-backend-refactor https://github.com/bitsandbytes-foundation/bitsandbytes.git
+          cd bitsandbytes
+          pip install .
+
       - name: Assert versions
         run: |
           python -c "import torch; print(torch.__version__); assert torch.__version__.startswith('${{ matrix.torch-version }}'.replace('.*', ''))"
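
The new step builds bitsandbytes from the multi-backend-refactor branch, which carries the CPU backend these tests exercise. A quick smoke test of the resulting environment, in the spirit of the Assert versions step above (a sketch; the imports beyond torch are assumptions about what the job installs):

import torch
import intel_extension_for_pytorch as ipex  # pinned to the torch version in the matrix
import bitsandbytes as bnb  # built from source in the step above

# Print the resolved stack so a failing job log shows exactly what was installed.
print(torch.__version__, ipex.__version__, bnb.__version__)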

optimum/exporters/ipex/model_patcher.py

+12 -15

@@ -14,7 +14,7 @@
 
 from transformers.models.bert.modeling_bert import BertIntermediate
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel
-from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
+from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer,
     LlamaModel,
@@ -32,13 +32,11 @@
 
 from .modeling_utils import (
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
-    _IPEXGPT2MLP,
     _falcon_model_forward,
-    _gpt2_block_forward,
     _gpt2_model_forward,
     _ipex_rms_layer_norm_forward,
     _IPEXFalconDecoderLayer,
-    _IPEXGPT2Attention,
+    _IPEXGPT2Block,
     _IPEXIntermediate,
     _IPEXLlamaDecoderLayer,
     _IPEXQwen2DecoderLayer,
@@ -66,12 +64,12 @@ def convert_functions(m, target_m, new_function_name, new_function):
         convert_functions(sub_m, target_m, new_function_name, new_function)
 
 
-def convert_class(m, target_m, new_class, config=None):
+def convert_class(m, target_m, new_class, device, config):
     for name, sub_m in m.named_children():
         if isinstance(sub_m, target_m):
-            new_m = new_class(sub_m, config)
+            new_m = new_class(sub_m, device, config)
             setattr(m, name, new_m)
-        convert_class(sub_m, target_m, new_class, config)
+        convert_class(sub_m, target_m, new_class, device, config)
 
 
 def patch_op(m, target_m, new_op_name, new_op):
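
convert_class now threads an explicit device (and always a config) down to every replacement class, so fused and possibly quantized weights can be placed correctly at wrap time instead of being guessed from the config later. A self-contained toy run of the same recursion; ToyWrapper only stands in for the _IPEX* classes and is not part of this diff:

import torch
from torch import nn

class ToyWrapper(nn.Module):
    # Same constructor contract the patcher now expects: (module, device, config).
    def __init__(self, module, device, config):
        super().__init__()
        self.module = module.to(device)
        self.config = config

    def forward(self, x):
        return self.module(x)

def convert_class(m, target_m, new_class, device, config):
    # Mirror of the patched helper: recursively swap every target_m instance.
    for name, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            setattr(m, name, new_class(sub_m, device, config))
        convert_class(sub_m, target_m, new_class, device, config)

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
convert_class(model, nn.Linear, ToyWrapper, torch.device("cpu"), config=None)
print(model)  # the Linear is now wrapped in ToyWrapper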
@@ -89,7 +87,7 @@ def _patch_llama_model(model):
     """
     convert_functions(model, LlamaModel, "forward", _llama_model_forward)
     convert_functions(model, LlamaRMSNorm, "forward", _ipex_rms_layer_norm_forward)
-    convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.config)
+    convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.device, model.config)
     return model
 
@@ -105,21 +103,20 @@ def _patch_falcon_model(model):
     setattr(model.config, "num_key_value_heads", num_key_value_heads)
     convert_functions(model, FalconModel, "forward", _falcon_model_forward)
     replace_customized_linear_with_linear(model)
-    convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.config)
+    convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.device, model.config)
     return model
 
 
 def _patch_gpt2_model(model):
     """
     Patch gpt2 model:
         1. Use IPEX paged attention
+        2. Linear fusion with (Linear + Add)
     """
     num_key_value_heads = model.config.num_attention_heads
     setattr(model.config, "num_key_value_heads", num_key_value_heads)
     convert_functions(model, GPT2Model, "forward", _gpt2_model_forward)
-    convert_functions(model, GPT2Block, "forward", _gpt2_block_forward)
-    convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config)
-    convert_class(model, GPT2MLP, _IPEXGPT2MLP, model.config)
+    convert_class(model, GPT2Block, _IPEXGPT2Block, model.device, model.config)
     return model
 
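GPT-2 is now patched at the block level: a single _IPEXGPT2Block replaces the whole GPT2Block instead of patching GPT2Attention and GPT2MLP separately, keeping the (Linear + Add) fusion and any quantized-linear handling in one class. A hypothetical sketch of the constructor contract such a wrapper must satisfy; the real implementation lives in modeling_utils.py and is not shown in this diff:

from torch import nn

class _ToyGPT2Block(nn.Module):
    # Illustrative stand-in for _IPEXGPT2Block: accept the original block,
    # the target device, and the model config, as convert_class now requires.
    def __init__(self, block, device, config):
        super().__init__()
        self.block = block  # a real wrapper would fuse or replace linears here
        self.device = device
        self.config = config

    def forward(self, hidden_states, **kwargs):
        return self.block(hidden_states, **kwargs)
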
@@ -131,7 +128,7 @@ def _patch_qwen2_model(model):
     """
     convert_functions(model, Qwen2Model, "forward", _qwen2_model_forward)
     convert_functions(model, Qwen2RMSNorm, "forward", _ipex_rms_layer_norm_forward)
-    convert_class(model, Qwen2DecoderLayer, _IPEXQwen2DecoderLayer, model.config)
+    convert_class(model, Qwen2DecoderLayer, _IPEXQwen2DecoderLayer, model.device, model.config)
     return model
 
@@ -140,7 +137,7 @@ def _patch_bert_model(model):
     Patch bert model:
         1. Linear fusion with Linear + Gelu
     """
-    convert_class(model, BertIntermediate, _IPEXIntermediate)
+    convert_class(model, BertIntermediate, _IPEXIntermediate, model.device, model.config)
     return model
 
@@ -149,7 +146,7 @@ def _patch_vit_model(model):
     Patch vit model:
         1. Linear fusion with Linear + Gelu
     """
-    convert_class(model, ViTIntermediate, _IPEXIntermediate)
+    convert_class(model, ViTIntermediate, _IPEXIntermediate, model.device, model.config)
     return model