
Commit 137fa3a

Add llm examples to SmoothQuant 3.x API (#1685)
Signed-off-by: Cheng, Zixuan <zixuan.cheng@intel.com>
1 parent 3bb284c commit 137fa3a

6 files changed (+169 lines, -6 lines)


examples/.config/model_params_pytorch.json

+14

@@ -520,6 +520,13 @@
       "batch_size": 1,
       "main_script": "run_clm_no_trainer.py"
     },
+    "llama2_7b_ipex":{
+      "model_src_dir": "nlp/huggingface_models/language-modeling/quantization/llm",
+      "dataset_location": "",
+      "input_model": "",
+      "main_script": "run_clm_no_trainer.py",
+      "batch_size": 1
+    },
     "llama2_7b_ipex_sq":{
       "model_src_dir": "nlp/huggingface_models/language-modeling/quantization/llm",
       "dataset_location": "",
@@ -548,6 +555,13 @@
       "main_script": "run_clm_no_trainer.py",
       "batch_size": 1
     },
+    "gpt_j_ipex":{
+      "model_src_dir": "nlp/huggingface_models/language-modeling/quantization/llm",
+      "dataset_location": "",
+      "input_model": "",
+      "main_script": "run_clm_no_trainer.py",
+      "batch_size": 1
+    },
     "gpt_j_ipex_sq":{
       "model_src_dir": "nlp/huggingface_models/language-modeling/quantization/llm",
       "dataset_location": "",

examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/README.md

+32, -1

@@ -21,6 +21,18 @@ Here is how to run the scripts:
 ### GPT-J-6b
 
 #### Quantization
+```bash
+# "--sq" is used to enable smooth quant
+python run_clm_no_trainer.py \
+    --model EleutherAI/gpt-j-6B \
+    --quantize \
+    --sq \
+    --alpha 1.0 \
+    --ipex \
+    --output_dir "saved_results"
+```
+**Note**: Smooth quantization here is based on torch.jit. Without past key values in example_inputs, the quantized model cannot be used for text generation.
+
 ```bash
 # "--approach weight_only" is used to enable weight only quantization.
 # "--woq_algo GPTQ" is used to enable GPTQ algorithms
@@ -62,6 +74,15 @@ python run_clm_no_trainer.py \
 #### Quantization
 
 ```bash
+# "--sq" is used to enable smooth quant
+python run_clm_no_trainer.py \
+    --model facebook/opt-125m \
+    --quantize \
+    --sq \
+    --alpha 0.5 \
+    --ipex \
+    --output_dir "saved_results"
+
 # "--approach weight_only" is used to enable weight only quantization.
 # "--woq_algo GPTQ" is used to enable GPTQ algorithms
 # "--double_quant_type BNB_NF4" is used to enable double quant algorithms
@@ -95,10 +116,20 @@ python run_clm_no_trainer.py \
     --double_quant_type "BNB_NF4"
 ```
 
-### LLAMA2-7b/13b/30b
+### LLAMA2-7b/13b/70b
+>Note: LLAMA requires IPEX >= 2.1 to get better accuracy.
 #### Quantization
 
 ```bash
+# "--sq" is used to enable smooth quant
+python run_clm_no_trainer.py \
+    --model meta-llama/Llama-2-7b-hf \
+    --quantize \
+    --sq \
+    --alpha 0.8 \
+    --ipex \
+    --output_dir "saved_results"
+
 # "--approach weight_only" is used to enable weight only quantization.
 # "--double_quant_type BNB_NF4" is used to enable double quant algorithms
 # "--woq_algo GPTQ" is used to enable GPTQ algorithms

examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py

+56, -3

@@ -331,9 +331,62 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args):
             model=user_model, quant_config=quant_config, run_fn=run_fn_for_gptq, run_args=(dataloader_for_calibration, )
         )
     else:
-        # TODO: smooth quant
-        print("Only support WeightOnlyQuant now")
+        if args.sq:
+            from neural_compressor.torch.quantization import SmoothQuantConfig, quantize
+
+            # alpha can be a float number or a list of float numbers.
+            args.alpha = args.alpha if args.alpha == "auto" else eval(args.alpha)
+            if re.search("falcon", user_model.config.model_type):
+                quant_config = SmoothQuantConfig(alpha=args.alpha, folding=False)
+            else:
+                quant_config = SmoothQuantConfig(alpha=args.alpha, folding=True)
+
+            if re.search("gpt", user_model.config.model_type):
+                quant_config.set_local("add", SmoothQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+        else:
+            from neural_compressor.torch.quantization import quantize, get_default_static_config, StaticQuantConfig
+
+            quant_config = get_default_static_config()
+            if re.search("gpt", user_model.config.model_type):
+                quant_config.set_local("add", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+
+        from neural_compressor.torch.algorithms.smooth_quant import move_input_to_device
+        from tqdm import tqdm
+        def run_fn(model):
+            for batch in tqdm(calib_dataloader):
+                batch = move_input_to_device(batch, device=None)
+                try:
+                    if isinstance(batch, tuple) or isinstance(batch, list):
+                        model(batch[0])
+                    elif isinstance(batch, dict):
+                        model(**batch)
+                    else:
+                        model(batch)
+                except ValueError:
+                    pass
+            return
+
+        from utils import get_example_inputs
+        example_inputs = get_example_inputs(user_model, calib_dataloader)
+        user_model = quantize(
+            model=user_model, quant_config=quant_config, example_inputs=example_inputs, run_fn=run_fn
+        )
+        user_model.save(args.output_dir)
+
+if args.int8 or args.int8_bf16_mixed:
+    print("load int8 model")
+
+    from neural_compressor.torch.algorithms.static_quant import load
+
+    if args.ipex:
+        user_model = load(os.path.abspath(os.path.expanduser(args.output_dir)))
+    else:
+        # TODO: WOQ save&load
+        print("Int8 model loading does not support WeightOnlyQuant now.")
         pass
+else:
+    user_model, _ = get_user_model()
+
 
 if args.accuracy:
     user_model.eval()
@@ -382,4 +435,4 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args):
     print("Accuracy: %.5f" % acc)
     print('Throughput: %.3f samples/sec' % (samples / (end - start)))
     print('Latency: %.3f ms' % ((end - start) * 1000 / samples))
-    print('Batch size = %d' % args.batch_size)
+    print('Batch size = %d' % args.batch_size)
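For the shape of the new SmoothQuant path without the surrounding script, here is a condensed sketch of the same 3.x API calls on a toy module; the toy model, alpha value, and random calibration data are placeholders, and IPEX plus the neural_compressor torch API are assumed to be installed:

```python
# Condensed sketch of the SmoothQuant flow added above; only the
# neural_compressor calls mirror the commit, the rest is illustrative.
import torch
from neural_compressor.torch.quantization import SmoothQuantConfig, quantize

class TinyMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(16, 16)
        self.fc2 = torch.nn.Linear(16, 4)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

model = TinyMLP().eval()
example_inputs = torch.randn(1, 16)          # stands in for get_example_inputs(...)
quant_config = SmoothQuantConfig(alpha=0.5, folding=True)

def run_fn(model):
    # calibration loop; the real script iterates calib_dataloader instead
    for _ in range(4):
        model(torch.randn(1, 16))

q_model = quantize(model=model, quant_config=quant_config,
                   example_inputs=example_inputs, run_fn=run_fn)
q_model.save("saved_results")
```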

examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_quant.sh

+18

@@ -56,6 +56,12 @@ function run_tuning {
         approach="weight_only"
         extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_use_mse_search --gptq_use_max_length --gptq_percdamp 0.1 --gptq_actorder"
         extra_cmd=$extra_cmd" --double_quant_type GGML_TYPE_Q4_K"
+    elif [ "${topology}" = "opt_125m_ipex" ]; then
+        model_name_or_path="facebook/opt-125m"
+        extra_cmd=$extra_cmd" --ipex"
+    elif [ "${topology}" = "opt_125m_ipex_sq" ]; then
+        model_name_or_path="facebook/opt-125m"
+        extra_cmd=$extra_cmd" --ipex --sq --alpha 0.5"
     elif [ "${topology}" = "llama2_7b_gptq_int4" ]; then
         model_name_or_path="meta-llama/Llama-2-7b-hf"
         approach="weight_only"
@@ -70,6 +76,12 @@ function run_tuning {
         approach="weight_only"
         extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_use_mse_search --gptq_use_max_length"
         extra_cmd=$extra_cmd" --double_quant_type GGML_TYPE_Q4_K"
+    elif [ "${topology}" = "llama2_7b_ipex" ]; then
+        model_name_or_path="meta-llama/Llama-2-7b-hf"
+        extra_cmd=$extra_cmd" --ipex"
+    elif [ "${topology}" = "llama2_7b_ipex_sq" ]; then
+        model_name_or_path="meta-llama/Llama-2-7b-hf"
+        extra_cmd=$extra_cmd" --ipex --sq --alpha 0.8"
     elif [ "${topology}" = "gpt_j_woq_rtn_int4" ]; then
         model_name_or_path="EleutherAI/gpt-j-6b"
         approach="weight_only"
@@ -98,6 +110,12 @@ function run_tuning {
         approach="weight_only"
         extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_use_mse_search --gptq_use_max_length"
         extra_cmd=$extra_cmd" --double_quant_type GGML_TYPE_Q4_K"
+    elif [ "${topology}" = "gpt_j_ipex" ]; then
+        model_name_or_path="EleutherAI/gpt-j-6b"
+        extra_cmd=$extra_cmd" --ipex"
+    elif [ "${topology}" = "gpt_j_ipex_sq" ]; then
+        model_name_or_path="EleutherAI/gpt-j-6b"
+        extra_cmd=$extra_cmd" --ipex --sq --alpha 1.0"
     fi
 
     python -u run_clm_no_trainer.py \

examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/utils.py

+48, -1

@@ -1,6 +1,9 @@
 import random
 import torch
+from collections import UserDict
+from packaging.version import Version
 from neural_compressor.common import logger
+from neural_compressor.torch.utils import get_torch_version
 
 class DataloaderPreprocessor:
     def __init__(self, dataloader_original, use_max_length=False, max_seq_length=2048, nsamples=128) -> None:
@@ -143,4 +146,48 @@ def obtain_first_n_samples_fulllength(self, seed=0):
             logger.warning(
                 f"Trying to allocate {self.nsamples} data with fixed length {unified_length}, \
                 but only {len(self.dataloader)} samples are found. Please use smaller 'self.max_seq_length' value."
-            )
+            )
+
+
+def get_example_inputs(model, dataloader):
+    version = get_torch_version()
+    from neural_compressor.torch.algorithms.smooth_quant import move_input_to_device
+
+    # Suggest set dataloader like calib_dataloader
+    if dataloader is None:
+        return None
+    device = next(model.parameters()).device
+    try:
+        for idx, (input, label) in enumerate(dataloader):
+            input = move_input_to_device(input, device)
+            if isinstance(input, (dict, UserDict)):  # pragma: no cover
+                assert version.release >= Version("1.12.0").release, "INC support IPEX version >= 1.12.0"
+                if "label" in input.keys():
+                    input.pop("label")
+                if version.release <= Version("2.0.1").release:
+                    return tuple(input.values())
+                else:
+                    return dict(input)
+            if isinstance(input, (list, tuple)):
+                return tuple(input)
+            if isinstance(input, torch.Tensor):
+                return input
+            break
+    except Exception as e:  # pragma: no cover
+        for idx, input in enumerate(dataloader):
+            input = move_input_to_device(input, device)
+            if isinstance(input, (dict, UserDict)):  # pragma: no cover
+                assert version.release >= Version("1.12.0").release, "INC support IPEX version >= 1.12.0"
+                if "label" in input.keys():
+                    input.pop("label")
+                if version.release <= Version("2.0.1").release:
+                    return tuple(input.values())
+                else:
+                    return dict(input)
+            if isinstance(input, list) or isinstance(input, tuple):
+                return tuple(input)
+            if isinstance(input, torch.Tensor):
+                return input
+            break
+    if idx == 0:
+        assert False, "Please checkout the example_inputs format."
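`get_example_inputs` walks the first calibration batch and returns it in a form `quantize()` can trace (dict, tuple, or plain tensor). A small sketch of its behavior with an ordinary `(input, label)` dataloader, assuming it is run from this example directory with neural-compressor installed; the toy model and dataset are illustrative:

```python
# Sketch: get_example_inputs returns the first input batch from the dataloader.
import torch
from torch.utils.data import DataLoader, TensorDataset

from utils import get_example_inputs

model = torch.nn.Linear(8, 2)
dataset = TensorDataset(torch.randn(16, 8), torch.zeros(16, dtype=torch.long))
loader = DataLoader(dataset, batch_size=4)

example_inputs = get_example_inputs(model, loader)
print(example_inputs.shape)  # torch.Size([4, 8]) -- the first input tensor
```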

examples/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/README.md

+1, -1

@@ -128,7 +128,7 @@ python run_clm_no_trainer.py \
 # to validate int8 model generated with `--sq`, please remove "--approach weight_only".
 # to validate the int8 model quantized with ipex, please include "--ipex".
 ```
-### LLAMA2-7b/13b/30b
+### LLAMA2-7b/13b/70b
 >Note: LLAMA requires IPEX >= 2.1 to get better accuracy.
 #### Quantization
 