
Commit 31a3e53

Fixed load issue for woq model and update docs
Signed-off-by: Cheng, Penghui <penghui.cheng@intel.com>
1 parent 0c44e0b commit 31a3e53

5 files changed: +155 −17 lines changed

docs/source/optimization_inc.mdx

+27
@@ -126,6 +126,33 @@ mpirun -np <number_of_processes> <RUN_CMD>

Please refer to INC [documentation](https://github.com/intel/neural-compressor/blob/master/docs/source/tuning_strategies.md#distributed-tuning) and [text-classification](https://github.com/huggingface/optimum-intel/tree/main/examples/neural_compressor/text-classification) example for more details.

+## Weight-only quantization
+As large language models (LLMs) become more prevalent, there is a growing need for quantization methods that can meet the computational demands of these modern architectures while maintaining accuracy. Compared to regular quantization such as W8A8, weight-only quantization is often a better trade-off between performance and accuracy. The "GPTQ" and "RTN" methods are currently supported.
+
+```python
+from intel_extension_for_transformers.transformers.utils.config import GPTQConfig, RtnConfig
+from optimum.intel import INCQuantizer
+
+# for GPTQ method
+quantization_config = GPTQConfig(
+    damp_percent=0.01,
+    weight_dtype="int4_clip",
+)
+
+# for RTN method
+quantization_config = RtnConfig(
+    weight_dtype="int4_clip",
+)
+
+# `model` is a loaded transformers model and `train_dataset` a calibration dataset (needed for GPTQ)
+quantizer = INCQuantizer.from_pretrained(model)
+quantizer.quantize(
+    quantization_config=quantization_config,
+    save_directory="output_dir",
+    calibration_dataset=(
+        train_dataset if quantization_config.quant_method == "gptq" else None
+    ),
+)
+q_model = quantizer._quantized_model
+```
+Please refer to [example](https://github.com/huggingface/optimum-intel/tree/main/examples/neural_compressor/text-generation).

## During training optimization
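Once saved this way, the weight-only quantized model can be reloaded for inference through `INCModelForCausalLM`, which is the load path this commit fixes. A minimal sketch, assuming the checkpoint and its tokenizer were both saved to `output_dir` as above:

```python
import torch
from transformers import AutoTokenizer

from optimum.intel import INCModelForCausalLM

# Reload the weight-only quantized checkpoint produced by the snippet above
# (assumes the tokenizer was also saved to the same directory).
model = INCModelForCausalLM.from_pretrained("output_dir")
tokenizer = AutoTokenizer.from_pretrained("output_dir")

inputs = tokenizer("Weight-only quantization is", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```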

examples/neural_compressor/text-generation/README.md

+11-1
@@ -20,7 +20,7 @@ Based on the script [`run_generation.py`](https://github.com/huggingface/transfo

The original generation task only supported the PyTorch eager model. By calling the `TSModelForCausalLM` class, we can now support a TorchScript model for generation tasks.

-This example also allows us to apply different quantization approaches (such as dynamic, static, The example applies post-training static quantization on a gptj model).
+This example also allows us to apply different quantization approaches such as dynamic, static, weight-only and quantization-aware training (the example applies post-training static quantization on a GPT-J model).

Example usage:
### apply_quantization with post-training static
@@ -45,3 +45,13 @@ python run_generation.py \
    --smooth_quant_alpha 0.7 \
    --jit
```
+
+### apply_quantization with weight-only quantization
+As large language models (LLMs) become more prevalent, there is a growing need for quantization methods that can meet the computational demands of these modern architectures while maintaining accuracy. Compared to regular quantization such as W8A8, weight-only quantization is often a better trade-off between performance and accuracy. The "GPTQ" and "RTN" methods are currently supported.
+```bash
+python run_generation.py \
+    --model_type=gptj \
+    --model_name_or_path=EleutherAI/gpt-j-6b \
+    --apply_quantization \
+    --quantization_approach weight_only
+```
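For the GPTQ methodology specifically, the new flags introduced in `run_generation.py` below can be added to the command above; a possible invocation (the values shown are illustrative, not recommendations):

```bash
python run_generation.py \
    --model_type=gptj \
    --model_name_or_path=EleutherAI/gpt-j-6b \
    --apply_quantization \
    --quantization_approach weight_only \
    --quantization_methodology gptq \
    --weight_dtype int4_clip \
    --group_size 128 \
    --damp_percent 0.01 \
    --gptq_block_size 128 \
    --num_calibration_samples 128
```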

examples/neural_compressor/text-generation/run_generation.py

+108-10
@@ -45,6 +45,14 @@
)

from optimum.intel.neural_compressor import INCModelForCausalLM, INCQuantizer
+from optimum.intel.utils.import_utils import (
+    INTEL_EXTENSION_FOR_TRANSFORMERS_IMPORT_ERROR,
+    is_intel_extension_for_transformers_available,
+)
+
+
+if is_intel_extension_for_transformers_available():
+    from intel_extension_for_transformers.transformers.utils.config import GPTQConfig, RtnConfig


logging.basicConfig(
@@ -281,6 +289,69 @@ def main():
    )
    parser.add_argument("--dataset_name", nargs="?", default="NeelNanda/pile-10k", const="NeelNanda/pile-10k")
    parser.add_argument("--calib_iters", default=100, type=int, help="calibration iters.")
+    parser.add_argument(
+        "--bits",
+        default="4",
+        type=str,
+        help="Bits number of weight for weight only quantization. 1~8 bits.",
+    )
+    parser.add_argument(
+        "--weight_dtype",
+        default="int4_clip",
+        type=str,
+        help="weight dtype for weight only quantization.",
+    )
+    parser.add_argument(
+        "--group_size",
+        default=32,
+        type=int,
+        help="Group size for weight only quantization. Group_size=[1-N] indicates "
+        "splitting the input channel elements per group_size. -1 indicates "
+        "the per-channel quantization per output channel.",
+    )
+    parser.add_argument(
+        "--weight_only_scheme",
+        default="sym",
+        type=str,
+        help="Scheme for weight only quantization. Choose from 'sym' and 'asym'.",
+    )
+    parser.add_argument(
+        "--quantization_methodology",
+        choices=["rtn", "gptq"],
+        default="rtn",
+        type=str,
+        help="Quantization methodology for weight only quantization. Choose from 'rtn' and 'gptq'.",
+    )
+    parser.add_argument(
+        "--damp_percent",
+        default=0.01,
+        type=float,
+        help="Percentage of Hessian's diagonal values average, which will be added to Hessian's diagonal to increase numerical stability, used for GPTQ quantization",
+    )
+    parser.add_argument(
+        "--gptq_block_size",
+        default=128,
+        type=int,
+        help="Block size. sub weight matrix size to run GPTQ.",
+    )
+    parser.add_argument(
+        "--num_calibration_samples",
+        default=128,
+        type=int,
+        help="Number of examples to use for the GPTQ calibration step.",
+    )
+    parser.add_argument(
+        "--use_max_length",
+        default=False,
+        type=bool,
+        help="Set all sequence length to be same length of args.gptq_pad_max_length",
+    )
+    parser.add_argument(
+        "--pad_max_length",
+        default=2048,
+        type=int,
+        help="Calibration dataset sequence max length, this should align with your model config",
+    )
    args = parser.parse_args()

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
@@ -313,6 +384,43 @@ def main():
    model.to(args.device)

    if args.apply_quantization:
+        supported_approach = {"static", "dynamic", "weight_only"}
+        if args.quantization_approach not in supported_approach:
+            raise ValueError(
+                f"Unknown quantization approach. Supported approaches are {supported_approach}. "
+                f"{args.quantization_approach} was given."
+            )
+        if args.quantization_approach == "weight_only":
+            if not is_intel_extension_for_transformers_available():
+                raise ImportError(INTEL_EXTENSION_FOR_TRANSFORMERS_IMPORT_ERROR.format("WeightOnly quantization"))
+
+            algorithm_args = {
+                "weight_dtype": args.weight_dtype,
+                "sym": args.weight_only_scheme == "sym",
+                "group_size": args.group_size,
+            }
+
+            if args.quantization_methodology == "gptq":
+                quantization_config = GPTQConfig(
+                    damp_percent=args.damp_percent,
+                    nsamples=args.num_calibration_samples,
+                    blocksize=args.gptq_block_size,
+                    **algorithm_args,
+                )
+            else:
+                quantization_config = RtnConfig(**algorithm_args)
+
+        else:
+            example_inputs = {"input_ids": torch.randint(100, (1, 32)), "attention_mask": torch.ones(1, 32)}
+            quantization_config = PostTrainingQuantConfig(
+                approach=args.quantization_approach,
+                recipes={
+                    "smooth_quant": args.smooth_quant,
+                    "smooth_quant_args": {"alpha": args.smooth_quant_alpha, "folding": True},
+                },
+                example_inputs=example_inputs,
+            )
+        model.config.return_dict = False
        # This is just an example for calibration_fn. If you want to achieve good accuracy,
        # you must perform a calibration on your real dataset.
        calib_dataset = load_dataset(args.dataset_name, split="train")
@@ -347,16 +455,6 @@ def calibration_fn(p_model):
                do_sample=False,
            )

-        example_inputs = {"input_ids": torch.randint(100, (1, 32)), "attention_mask": torch.ones(1, 32)}
-        quantization_config = PostTrainingQuantConfig(
-            approach=args.quantization_approach,
-            recipes={
-                "smooth_quant": args.smooth_quant,
-                "smooth_quant_args": {"alpha": args.smooth_quant_alpha, "folding": True},
-            },
-            example_inputs=example_inputs,
-        )
-        model.config.return_dict = False
        quantizer = INCQuantizer.from_pretrained(model, calibration_fn=calibration_fn)
        with tempfile.TemporaryDirectory() as tmp_dir:
            quantizer.quantize(

optimum/intel/neural_compressor/modeling_base.py

+8-6
@@ -148,13 +148,15 @@ def _from_pretrained(

            return _BaseQBitsAutoModelClass.from_pretrained(
                pretrained_model_name_or_path=model_id,
-                use_auth_token=use_auth_token,
-                revision=revision,
-                force_download=force_download,
-                cache_dir=cache_dir,
-                local_files_only=local_files_only,
-                subfolder=subfolder,
+                # The following parameters are not supported in itrex 1.4 and will be supported in the next version
+                # use_auth_token=use_auth_token,
+                # revision=revision,
+                # force_download=force_download,
+                # cache_dir=cache_dir,
+                # local_files_only=local_files_only,
+                # subfolder=subfolder,
                trust_remote_code=trust_remote_code,
+                use_neural_speed=False,
                **kwargs,
            )
        except EnvironmentError:

optimum/intel/neural_compressor/quantization.py

+1
@@ -297,6 +297,7 @@ def quantize(
            )

            self._quantized_model.quantization_config = quantization_config
+            self._quantized_model.config.quantization_config = quantization_config
            self._quantized_model.save_pretrained = types.MethodType(save_low_bit, self._quantized_model)
            self._quantized_model.save_pretrained(save_directory)
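The added line also records the quantization settings on the model's `config`, so they can be serialized with the checkpoint and picked up again at load time. A minimal sanity check, as a sketch with an illustrative path (`output_dir`), assuming the model was saved there as in the docs snippet above:

```python
import json

# Inspect the config written alongside the quantized checkpoint (path is illustrative).
with open("output_dir/config.json") as f:
    config = json.load(f)

# With the change above, the weight-only quantization settings should appear here,
# which is what lets the load path recognize a weight-only quantized model.
print(config.get("quantization_config"))
```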
