Commit ea6fa42
[OV] Introduce --quant-mode cli argument enabling full quantization via optimum-cli (#1061)
* Introduce --quant-mode cli argument
* Make int8 by default
* Add a test
* Add documentation
* Fix command
* Replace 'int8/int8' by 'int8'
* Add missing docstring
* Add trust_remote_code
* Fix condition
* Trigger Tests
* Trigger Tests
1 parent 0a09651 commit ea6fa42

File tree: 6 files changed (+155 -11 lines)

docs/source/openvino/export.mdx (+17 -3)

@@ -31,13 +31,14 @@ Check out the help for more options:
 
 ```text
 usage: optimum-cli export openvino [-h] -m MODEL [--task TASK] [--framework {pt,tf}] [--trust-remote-code]
-                                   [--weight-format {fp32,fp16,int8,int4,mxfp4,nf4}]
+                                   [--weight-format {fp32,fp16,int8,int4,mxfp4,nf4}] [--quant-mode {int8}]
                                    [--library {transformers,diffusers,timm,sentence_transformers,open_clip}]
                                    [--cache_dir CACHE_DIR] [--pad-token-id PAD_TOKEN_ID] [--ratio RATIO] [--sym]
                                    [--group-size GROUP_SIZE] [--backup-precision {none,int8_sym,int8_asym}]
                                    [--dataset DATASET] [--all-layers] [--awq] [--scale-estimation] [--gptq]
                                    [--lora-correction] [--sensitivity-metric SENSITIVITY_METRIC]
                                    [--num-samples NUM_SAMPLES] [--disable-stateful] [--disable-convert-tokenizer]
+                                   [--smooth-quant-alpha SMOOTH_QUANT_ALPHA]
                                    output
 
 optional arguments:
@@ -66,6 +67,10 @@ Optional arguments:
                         on your local machine arbitrary code present in the model repository.
   --weight-format {fp32,fp16,int8,int4,mxfp4,nf4}
                         The weight format of the exported model.
+  --quant-mode {int8}
+                        Quantization precision mode. This is used for applying full model quantization including
+                        activations. The only currently supported choice is 'int8' for int8 quantization of both
+                        weights and activations.
   --library {transformers,diffusers,timm,sentence_transformers,open_clip}
                         The library used to load the model before export. If not provided, will attempt to infer the
                         local checkpoint's library
@@ -102,8 +107,8 @@ Optional arguments:
                         weight compression is applied, they are compressed to INT8.
   --awq                 Whether to apply AWQ algorithm. AWQ improves generation quality of INT4-compressed LLMs, but
                         requires additional time for tuning weights on a calibration dataset. To run AWQ, please also
-                        provide a dataset argument. Note: it is possible that there will be no matching patterns in the
-                        model to apply AWQ, in such case it will be skipped.
+                        provide a dataset argument. Note: it is possible that there will be no matching patterns in
+                        the model to apply AWQ, in such case it will be skipped.
   --scale-estimation    Indicates whether to apply a scale estimation algorithm that minimizes the L2 error between
                         the original and compressed layers. Providing a dataset is required to run scale estimation.
                         Please note, that applying scale estimation takes additional memory and time.
@@ -128,6 +133,9 @@ Optional arguments:
                         OpenVINO native inference code that expects KV-cache inputs and outputs in the model.
   --disable-convert-tokenizer
                         Do not add converted tokenizer and detokenizer OpenVINO models.
+  --smooth-quant-alpha SMOOTH_QUANT_ALPHA
+                        SmoothQuant alpha parameter that improves the distribution of activations before MatMul layers
+                        and reduces quantization error. Valid only when activations quantization is enabled.
 ```
 
 You can also apply fp16, 8-bit or 4-bit weight-only quantization on the Linear, Convolutional and Embedding layers when exporting your model by setting `--weight-format` to respectively `fp16`, `int8` or `int4`.
@@ -158,6 +166,12 @@ Models larger than 1 billion parameters are exported to the OpenVINO format with
 </Tip>
 
 
+Besides weight-only quantization, you can also apply full model quantization including activations by setting `--quant-mode` to `int8`. This will quantize both weights and activations of Linear, Convolutional and some other layers to int8. Currently this is only supported for speech-to-text models. Please see example below.
+
+```bash
+optimum-cli export openvino -m openai/whisper-large-v3-turbo --quant-mode int8 --dataset librispeech --num-samples 32 --smooth-quant-alpha 0.9 ./whisper-large-v3-turbo
+```
+
 ### Decoder models
 
 For models with a decoder, we enable the re-use of past keys and values by default. This allows to avoid recomputing the same intermediate activations at each generation step. To export the model without, you will need to remove the `-with-past` suffix when specifying the task.
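The documentation above shows the CLI route. For orientation, a rough Python-API counterpart can be sketched with the classes this commit touches (`OVQuantizationConfig`, `OVModelForSpeechSeq2Seq`); the `from_pretrained(..., export=True, quantization_config=...)` keywords follow general optimum-intel conventions and are an assumption here, not something this diff adds:

```python
# Sketch (not from this commit): a possible Python-API counterpart of
#   optimum-cli export openvino -m openai/whisper-large-v3-turbo --quant-mode int8 ...
# Keyword names follow optimum-intel conventions; treat them as assumptions.
from optimum.intel import OVModelForSpeechSeq2Seq, OVQuantizationConfig

quantization_config = OVQuantizationConfig(
    bits=8,                 # int8 for both weights and activations
    dataset="librispeech",  # calibration data, as in the CLI example above
    num_samples=32,
    smooth_quant_alpha=0.9,
)

# export=True converts the PyTorch checkpoint to OpenVINO IR on load
model = OVModelForSpeechSeq2Seq.from_pretrained(
    "openai/whisper-large-v3-turbo",
    export=True,
    quantization_config=quantization_config,
)
model.save_pretrained("./whisper-large-v3-turbo")
```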

optimum/commands/export/openvino.py (+64 -5)

@@ -75,6 +75,16 @@ def parse_args_openvino(parser: "ArgumentParser"):
         default=None,
         help="The weight format of the exported model.",
     )
+    optional_group.add_argument(
+        "--quant-mode",
+        type=str,
+        choices=["int8"],
+        default=None,
+        help=(
+            "Quantization precision mode. This is used for applying full model quantization including activations. "
+            "The only currently supported choice is 'int8' for int8 quantization of both weights and activations."
+        ),
+    )
     optional_group.add_argument(
         "--library",
         type=str,
@@ -228,6 +238,15 @@ def parse_args_openvino(parser: "ArgumentParser"):
         action="store_true",
         help="Do not add converted tokenizer and detokenizer OpenVINO models.",
     )
+    optional_group.add_argument(
+        "--smooth-quant-alpha",
+        type=float,
+        default=None,
+        help=(
+            "SmoothQuant alpha parameter that improves the distribution of activations before MatMul layers and "
+            "reduces quantization error. Valid only when activations quantization is enabled."
+        ),
+    )
 
 
 def no_compression_parameter_provided(args):
@@ -252,6 +271,20 @@ def no_compression_parameter_provided(args):
     )
 
 
+def no_quantization_parameter_provided(args):
+    return all(
+        (
+            it is None
+            for it in (
+                args.sym,
+                args.dataset,
+                args.num_samples,
+                args.smooth_quant_alpha,
+            )
+        )
+    )
+
+
 class OVExportCommand(BaseOptimumCLICommand):
     COMMAND = CommandInfo(name="openvino", help="Export PyTorch models to OpenVINO IR.")
 
@@ -291,16 +324,21 @@ def run(self):
         else:
             library_name = self.args.library
 
-        if self.args.weight_format is None:
+        if self.args.weight_format is None and self.args.quant_mode is None:
             ov_config = None
             if not no_compression_parameter_provided(self.args):
                 raise ValueError(
                     "Some compression parameters are provided, but the weight format is not specified. "
                     "Please provide it with --weight-format argument."
                 )
+            if not no_quantization_parameter_provided(self.args):
+                raise ValueError(
+                    "Some quantization parameters are provided, but the quantization mode is not specified. "
+                    "Please provide it with --quant-mode argument."
+                )
         elif self.args.weight_format in {"fp16", "fp32"}:
             ov_config = OVConfig(dtype=self.args.weight_format)
-        else:
+        elif self.args.weight_format is not None:
             # For int4 quantization if no parameter is provided, then use the default config if exists
             if no_compression_parameter_provided(self.args) and self.args.weight_format == "int4":
                 quantization_config = get_default_int4_config(self.args.model)
@@ -326,6 +364,21 @@ def run(self):
             if quantization_config.get("dataset", None) is not None:
                 quantization_config["trust_remote_code"] = self.args.trust_remote_code
             ov_config = OVConfig(quantization_config=quantization_config)
+        else:
+            if self.args.quant_mode != "int8":
+                raise ValueError("Only 'int8' quantization mode is currently supported.")
+
+            quantization_config = {
+                "weight_format": self.args.quant_mode,
+                "activation_format": self.args.quant_mode,
+                "bits": 8,
+                "sym": self.args.sym or False,
+                "dataset": self.args.dataset,
+                "num_samples": self.args.num_samples,
+                "smooth_quant_alpha": self.args.smooth_quant_alpha,
+                "trust_remote_code": self.args.trust_remote_code,
+            }
+            ov_config = OVConfig(quantization_config=quantization_config)
 
         quantization_config = ov_config.quantization_config if ov_config else None
         quantize_with_dataset = quantization_config and getattr(quantization_config, "dataset", None) is not None
@@ -368,17 +421,23 @@ def run(self):
             model.save_pretrained(self.args.output)
             if not self.args.disable_convert_tokenizer:
                 maybe_convert_tokenizers(library_name, self.args.output, model, task=task)
-        elif (task.startswith("text-generation") and quantize_with_dataset) or (
-            task == "image-text-to-text" and quantization_config is not None
+        elif (
+            quantize_with_dataset
+            and (task.startswith("text-generation") or task == "automatic-speech-recognition")
+            or (task == "image-text-to-text" and quantization_config is not None)
         ):
             if task.startswith("text-generation"):
                 from optimum.intel import OVModelForCausalLM

                 model_cls = OVModelForCausalLM
-            else:
+            elif task == "image-text-to-text":
                 from optimum.intel import OVModelForVisualCausalLM

                 model_cls = OVModelForVisualCausalLM
+            else:
+                from optimum.intel import OVModelForSpeechSeq2Seq
+
+                model_cls = OVModelForSpeechSeq2Seq

             # In this case, to apply quantization an instance of a model class is required
             model = model_cls.from_pretrained(
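Concretely, for `--quant-mode int8` the new `else` branch above assembles a plain dict and wraps it in `OVConfig`. A standalone sketch with illustrative values standing in for the parsed CLI arguments:

```python
# Sketch: the quantization config built by run() for --quant-mode int8,
# with example values in place of the parsed CLI arguments.
from optimum.intel import OVConfig

quantization_config = {
    "weight_format": "int8",      # mirrors --quant-mode
    "activation_format": "int8",  # full quantization covers activations too
    "bits": 8,
    "sym": False,                 # --sym not passed
    "dataset": "librispeech",     # --dataset
    "num_samples": 32,            # --num-samples
    "smooth_quant_alpha": 0.9,    # --smooth-quant-alpha
    "trust_remote_code": False,
}
ov_config = OVConfig(quantization_config=quantization_config)
```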

optimum/intel/openvino/configuration.py (+30 -2)

@@ -266,6 +266,7 @@ def __init__(
         tokenizer: Optional[str] = None,
         processor: Optional[str] = None,
         trust_remote_code: bool = False,
+        weight_format: Optional[str] = None,
         **kwargs,
     ):
         """
@@ -279,6 +280,18 @@ def __init__(
                 entries provided via this argument are used to create an instance of `nncf.IgnoredScope` class.
             num_samples (`int`, *optional*):
                 The maximum number of samples composing the calibration dataset.
+            dataset (`str or List[str]`, *optional*):
+                The dataset used for data-aware optimization with NNCF.
+            tokenizer (`str`, *optional*):
+                The tokenizer used to process the dataset.
+            processor (`str`, *optional*):
+                A transformers processor used to process the dataset inputs.
+            trust_remote_code (`bool`, defaults to `False`):
+                Allows to use custom code for the modeling hosted in the model repository. This option should only be
+                set for repositories you trust and in which you have read the code, as it will execute on your local
+                machine arbitrary code present in the model repository.
+            weight_format (`str`, *optional*):
+                Data format weights are compressed to.
         """
         self.bits = bits
         self.sym = sym
@@ -287,6 +300,7 @@ def __init__(
         self.tokenizer = tokenizer
         self.processor = processor
         self.trust_remote_code = trust_remote_code
+        self.weight_format = weight_format
 
         if isinstance(ignored_scope, nncf.IgnoredScope):
             ignored_scope = ignored_scope.__dict__
@@ -370,7 +384,7 @@ class OVWeightQuantizationConfig(OVQuantizationConfigBase):
         scale_estimation (`bool`, *optional*):
             Indicates whether to apply a scale estimation algorithm that minimizes the L2 error between the original and
             compressed layers. Providing a dataset is required to run scale estimation.
-        weight_format (`str`, defaults to 'int'):
+        weight_format (`str`, *optional*):
             Data format weights are compressed to. Possible values: ['int4', 'int8', 'mxfp4', 'nf4'].
         qptq (`bool`, *optional*):
             Whether to apply GPTQ algorithm. GPTQ optimizes compressed weights in a layer-wise fashion to minimize the
@@ -425,14 +439,14 @@ def __init__(
             tokenizer=tokenizer,
             processor=processor,
             trust_remote_code=trust_remote_code,
+            weight_format=weight_format,
         )
         self.group_size = group_size or (-1 if bits == 8 else 128)
         self.ratio = ratio
         self.all_layers = all_layers
         self.sensitivity_metric = sensitivity_metric
         self.quant_method = OVQuantizationMethod(quant_method) if isinstance(quant_method, str) else quant_method
         self.scale_estimation = scale_estimation
-        self.weight_format = weight_format
         self.gptq = gptq
         self.lora_correction = lora_correction
         self.backup_precision = backup_precision
@@ -578,6 +592,8 @@ def __init__(
         processor: Optional[str] = None,
         trust_remote_code: bool = False,
         smooth_quant_alpha: Optional[float] = None,
+        weight_format: Optional[str] = "int8",
+        activation_format: Optional[str] = "int8",
         **kwargs,
     ):
         """
@@ -621,6 +637,10 @@ def __init__(
             smooth_quant_alpha (`float`, *optional*):
                 SmoothQuant alpha parameter that improves the distribution of activations before MatMul layers and
                 reduces quantization error.
+            weight_format (`str`, defaults to "int8"):
+                Data format weights are quantized to. Possible values: ['int8'].
+            activation_format (`str`, defaults to "int8"):
+                Data format activations are compressed to. Possible values: ['int8'].
         """
         super().__init__(
             bits=bits,
@@ -631,11 +651,13 @@ def __init__(
             tokenizer=tokenizer,
             processor=processor,
             trust_remote_code=trust_remote_code,
+            weight_format=weight_format,
         )
         self.model_type = model_type
         self.fast_bias_correction = fast_bias_correction
         self.overflow_fix = overflow_fix
         self.smooth_quant_alpha = smooth_quant_alpha
+        self.activation_format = activation_format
         self.post_init()
 
     def post_init(self):
@@ -659,6 +681,12 @@ def post_init(self):
                 f"SmoothQuant alpha parameter must be in range [0, 1], but found {self.smooth_quant_alpha}"
             )
 
+        if self.weight_format != "int8":
+            raise ValueError("Only 'int8' weight format is currently supported.")
+
+        if self.activation_format != "int8":
+            raise ValueError("Only 'int8' activation format is currently supported.")
+
 
 class OVConfig(BaseConfig):
     CONFIG_NAME = "openvino_config.json"
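With the new `post_init` checks, an `OVQuantizationConfig` now fails fast on anything other than int8 formats. A brief sketch of the expected behavior (public import path assumed):

```python
# Sketch: defaults pass validation, non-int8 formats raise (per post_init above).
from optimum.intel import OVQuantizationConfig

# weight_format and activation_format both default to "int8" -> valid
config = OVQuantizationConfig(bits=8, smooth_quant_alpha=0.9)

try:
    OVQuantizationConfig(bits=8, activation_format="int4")
except ValueError as err:
    print(err)  # Only 'int8' activation format is currently supported.
```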

optimum/intel/openvino/quantization.py (+5 -0)

@@ -458,6 +458,11 @@ def _quantize_ovbasemodel(
         if calibration_dataset is None:
             raise ValueError("Calibration dataset is required to run quantization.")
 
+        if quantization_config.weight_format != "int8":
+            raise ValueError("Only 'int8' weight format is currently supported.")
+        if quantization_config.activation_format != "int8":
+            raise ValueError("Only 'int8' activation format is currently supported.")
+
         # Quantize model(s)
         if isinstance(self.model, _OVModelForWhisper):
             self._quantize_whisper_model(quantization_config, calibration_dataset, **kwargs)
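These checks live in `_quantize_ovbasemodel`, so they apply to any quantization entry point that reaches it, not only the CLI. One such path is `OVQuantizer`; the sketch below is illustrative only, and the exact `quantize()` keyword names are assumptions, not taken from this diff:

```python
# Sketch (keyword names are assumptions): quantizing an already-exported
# OpenVINO model; this path runs through the checks added above, so any
# non-int8 weight/activation format in the config raises a ValueError.
from optimum.intel import OVConfig, OVModelForSpeechSeq2Seq, OVQuantizer

model = OVModelForSpeechSeq2Seq.from_pretrained("./whisper-large-v3-turbo")
quantizer = OVQuantizer.from_pretrained(model)
quantizer.quantize(
    ov_config=OVConfig(
        quantization_config={
            "weight_format": "int8",
            "activation_format": "int8",
            "bits": 8,
            "dataset": "librispeech",
        }
    ),
    save_directory="./whisper-large-v3-turbo-int8",
)
```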

optimum/intel/openvino/utils.py (+1 -0)

@@ -131,6 +131,7 @@
     "open_clip_text": "OVModelOpenCLIPText",
     "open_clip_vision": "OVModelOpenCLIPVisual",
     "open_clip": "OVModelOpenCLIPForZeroShotImageClassification",
+    "automatic-speech-recognition": "OVModelForSpeechSeq2Seq",
 }
 
tests/openvino/test_exporters_cli.py (+38 -1)

@@ -14,7 +14,7 @@
 import subprocess
 import unittest
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Tuple
 
 from parameterized import parameterized
 from transformers import AutoModelForCausalLM
@@ -37,6 +37,7 @@
     OVModelForQuestionAnswering,
     OVModelForSeq2SeqLM,
     OVModelForSequenceClassification,
+    OVModelForSpeechSeq2Seq,
     OVModelForTokenClassification,
     OVModelForVisualCausalLM,
     OVModelOpenCLIPForZeroShotImageClassification,
@@ -109,6 +110,16 @@ class OVCLIExportTestCase(unittest.TestCase):
         SUPPORTED_SD_HYBRID_ARCHITECTURES.append(("stable-diffusion-3", 9, 65))
         SUPPORTED_SD_HYBRID_ARCHITECTURES.append(("flux", 7, 56))
 
+    SUPPORTED_QUANTIZATION_ARCHITECTURES = [
+        (
+            "automatic-speech-recognition",
+            "whisper",
+            "--quant-mode int8 --dataset librispeech --num-samples 1 --smooth-quant-alpha 0.9 --trust-remote-code",
+            (14, 22, 21) if is_transformers_version("<=", "4.36.0") else (14, 22, 25),
+            (14, 21, 17) if is_transformers_version("<=", "4.36.0") else (14, 22, 18),
+        ),
+    ]
+
     TEST_4BIT_CONFIGURATIONS = [
         ("text-generation-with-past", "opt125m", "int4 --sym --group-size 128", [{"int8": 4, "int4": 72}]),
         ("text-generation-with-past", "opt125m", "int4 --group-size 64", [{"int8": 4, "int4": 144}]),
@@ -391,6 +402,32 @@ def test_exporters_cli_4bit(
             "--lora-correction" not in option or b"with correction of low-rank adapters" in result.stdout
         )
 
+    @parameterized.expand(SUPPORTED_QUANTIZATION_ARCHITECTURES)
+    def test_exporters_cli_full_quantization(
+        self,
+        task: str,
+        model_type: str,
+        option: str,
+        expected_num_fq_nodes_per_model: Tuple[int],
+        expected_num_weight_nodes_per_model: Tuple[int],
+    ):
+        with TemporaryDirectory() as tmpdir:
+            subprocess.run(
+                f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} {option} {tmpdir}",
+                shell=True,
+                check=True,
+            )
+            model = eval(_HEAD_TO_AUTOMODELS[task]).from_pretrained(tmpdir)
+
+            submodels = []
+            if task == "automatic-speech-recognition":
+                submodels = [model.encoder, model.decoder, model.decoder_with_past]
+            self.assertEqual(len(expected_num_fq_nodes_per_model), len(submodels))
+            for i, model in enumerate(submodels):
+                actual_num_fq_nodes, actual_num_weight_nodes = get_num_quantized_nodes(model)
+                self.assertEqual(expected_num_fq_nodes_per_model[i], actual_num_fq_nodes)
+                self.assertEqual(expected_num_weight_nodes_per_model[i], actual_num_weight_nodes["int8"])
+
     def test_exporters_cli_int4_with_local_model_and_default_config(self):
         with TemporaryDirectory() as tmpdir:
             pt_model = AutoModelForCausalLM.from_pretrained(MODEL_NAMES["falcon-40b"])
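To exercise the new test in isolation, a pytest keyword filter over the test file in this diff works; sketched here as a Python subprocess call to keep the examples in one language:

```python
# Sketch: run only the new full-quantization CLI test.
import subprocess

subprocess.run(
    ["pytest", "tests/openvino/test_exporters_cli.py", "-k", "full_quantization", "-v"],
    check=True,
)
```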
