Commit 878b474

Fp8 implementation (#1100)
* Fp8 implementation
* All datasets support
* Added test
* Update test
* Correctness
* Correctness
* Update docs/source/openvino/export.mdx

Co-authored-by: Alexander Kozlov <alexander.kozlov@intel.com>

* Change test model
* Apply comments

---------

Co-authored-by: Alexander Kozlov <alexander.kozlov@intel.com>
1 parent feaf027 commit 878b474

6 files changed: +65 −66 lines

docs/source/openvino/export.mdx (+4 −5)

@@ -31,7 +31,7 @@ Check out the help for more options:
 
 ```text
 usage: optimum-cli export openvino [-h] -m MODEL [--task TASK] [--framework {pt,tf}] [--trust-remote-code]
-                                   [--weight-format {fp32,fp16,int8,int4,mxfp4,nf4}] [--quant-mode {int8}]
+                                   [--weight-format {fp32,fp16,int8,int4,mxfp4,nf4}] [--quant-mode {int8,f8e4m3,f8e5m2}]
                                    [--library {transformers,diffusers,timm,sentence_transformers,open_clip}]
                                    [--cache_dir CACHE_DIR] [--pad-token-id PAD_TOKEN_ID] [--ratio RATIO] [--sym]
                                    [--group-size GROUP_SIZE] [--backup-precision {none,int8_sym,int8_asym}]
@@ -67,10 +67,9 @@ Optional arguments:
                         on your local machine arbitrary code present in the model repository.
   --weight-format {fp32,fp16,int8,int4,mxfp4,nf4}
                         The weight format of the exported model.
-  --quant-mode {int8}
+  --quant-mode {int8,f8e4m3,f8e5m2}
                         Quantization precision mode. This is used for applying full model quantization including
-                        activations. The only currently supported choice is 'int8' for int8 quantization of both
-                        weights and activations.
+                        activations.
   --library {transformers,diffusers,timm,sentence_transformers,open_clip}
                         The library used to load the model before export. If not provided, will attempt to infer the
                         local checkpoint's library
@@ -166,7 +165,7 @@ Models larger than 1 billion parameters are exported to the OpenVINO format with
 </Tip>
 
 
-Besides weight-only quantization, you can also apply full model quantization including activations by setting `--quant-mode` to `int8`. This will quantize both weights and activations of Linear, Convolutional and some other layers to int8. Currently this is only supported for speech-to-text models. Please see example below.
+Besides weight-only quantization, you can also apply full model quantization including activations by setting `--quant-mode` to the preferred precision. This will quantize both weights and activations of Linear, Convolutional and some other layers to the selected mode. Please see the example below.
 
 ```bash
 optimum-cli export openvino -m openai/whisper-large-v3-turbo --quant-mode int8 --dataset librispeech --num-samples 32 --smooth-quant-alpha 0.9 ./whisper-large-v3-turbo
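
With the new choices, the same command extends to FP8. A hedged sketch for the `f8e4m3` mode, assuming a causal-LM checkpoint (the model name and output directory are illustrative; `wikitext2` is the dataset the new test below uses):

```bash
optimum-cli export openvino -m meta-llama/Llama-3.2-1B --quant-mode f8e4m3 --dataset wikitext2 --num-samples 32 --smooth-quant-alpha 0.9 ./llama-f8e4m3
```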

optimum/commands/export/openvino.py (+1 −5)

@@ -78,11 +78,10 @@ def parse_args_openvino(parser: "ArgumentParser"):
    optional_group.add_argument(
        "--quant-mode",
        type=str,
-        choices=["int8"],
+        choices=["int8", "f8e4m3", "f8e5m2"],
        default=None,
        help=(
            "Quantization precision mode. This is used for applying full model quantization including activations. "
-            "The only currently supported choice is 'int8' for int8 quantization of both weights and activations."
        ),
    )
    optional_group.add_argument(
@@ -365,9 +364,6 @@ def run(self):
            quantization_config["trust_remote_code"] = self.args.trust_remote_code
            ov_config = OVConfig(quantization_config=quantization_config)
        else:
-            if self.args.quant_mode != "int8":
-                raise ValueError("Only 'int8' quantization mode is currently supported.")
-
            quantization_config = {
                "weight_format": self.args.quant_mode,
                "activation_format": self.args.quant_mode,

optimum/intel/openvino/configuration.py (+11 −20)

@@ -26,7 +26,7 @@
 from optimum.configuration_utils import BaseConfig
 
 from ..utils.import_utils import is_nncf_available
-from .utils import PREDEFINED_SD_DATASETS, PREDEFINED_SPEECH_TO_TEXT_DATASETS, PREDEFINED_VISUAL_LM_DATASETS
+from .utils import PREDEFINED_SD_DATASETS, PREDEFINED_VISUAL_LM_DATASETS
 
 
 if is_nncf_available():
@@ -638,9 +638,9 @@ def __init__(
                SmoothQuant alpha parameter that improves the distribution of activations before MatMul layers and
                reduces quantization error.
            weight_format (`str`, defaults to "int8"):
-                Data format weights are quantized to. Possible values: ['int8'].
+                Data format weights are quantized to. Possible values: ['int8', 'f8e4m3', 'f8e5m2'].
            activation_format (`str`, defaults to "int8"):
-                Data format activations are compressed to. Possible values: ['int8'].
+                Data format activations are compressed to. Possible values: ['int8', 'f8e4m3', 'f8e5m2'].
        """
        super().__init__(
            bits=bits,
@@ -658,6 +658,13 @@ def __init__(
        self.overflow_fix = overflow_fix
        self.smooth_quant_alpha = smooth_quant_alpha
        self.activation_format = activation_format
+
+        f8_formats = ["f8e4m3", "f8e5m2"]
+        if self.activation_format in f8_formats and self.weight_format in f8_formats:
+            logger.info(
+                f"{self.activation_format} for activations and {self.weight_format} weights were found. A symmetrical scheme will be used."
+            )
+            self.sym = True
        self.post_init()
 
    def post_init(self):
@@ -669,24 +676,11 @@ def post_init(self):
        if self.bits != 8:
            raise ValueError(f"Only support 8-bit for static quantization but found {self.bits}")
 
-        if self.dataset is not None:
-            if self.dataset not in PREDEFINED_SPEECH_TO_TEXT_DATASETS:
-                raise ValueError(
-                    f"You have entered the following string value for dataset: {self.dataset}. But it is not supported."
-                    f" Currently you can only choose {list(PREDEFINED_SPEECH_TO_TEXT_DATASETS.keys())}."
-                )
-
        if self.smooth_quant_alpha is not None and not (0 <= self.smooth_quant_alpha <= 1):
            raise ValueError(
                f"SmoothQuant alpha parameter must be in range [0, 1], but found {self.smooth_quant_alpha}"
            )
 
-        if self.weight_format != "int8":
-            raise ValueError("Only 'int8' weight format is currently supported.")
-
-        if self.activation_format != "int8":
-            raise ValueError("Only 'int8' activation format is currently supported.")
-
 
 class OVConfig(BaseConfig):
    CONFIG_NAME = "openvino_config.json"
@@ -711,10 +705,7 @@ def __init__(
            "compression", None
        )  # A field for backward-compatability of training-time compression parameters
        if self.quantization_config is not None:
-            if isinstance(self.quantization_config, OVWeightQuantizationConfig):
-                self.dtype = self.quantization_config.weight_format
-            else:
-                self.dtype = "int8"
+            self.dtype = self.quantization_config.weight_format
        else:
            self.dtype = dtype
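
A minimal usage sketch of the new symmetry behaviour, assuming the `__init__` shown above belongs to `OVQuantizationConfig` and that it is importable from `optimum.intel`:

```python
from optimum.intel import OVQuantizationConfig

# When both weights and activations use an FP8 format, __init__ above
# logs a message and forces a symmetrical quantization scheme.
config = OVQuantizationConfig(
    bits=8,                      # post_init() still enforces 8-bit
    weight_format="f8e4m3",      # both weights ...
    activation_format="f8e4m3",  # ... and activations in an FP8 format
)
assert config.sym  # set automatically by the f8_formats check
```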

optimum/intel/openvino/quantization.py (+8 −5)

@@ -458,11 +458,6 @@ def _quantize_ovbasemodel(
        if calibration_dataset is None:
            raise ValueError("Calibration dataset is required to run quantization.")
 
-        if quantization_config.weight_format != "int8":
-            raise ValueError("Only 'int8' weight format is currently supported.")
-        if quantization_config.activation_format != "int8":
-            raise ValueError("Only 'int8' activation format is currently supported.")
-
        # Quantize model(s)
        if isinstance(self.model, _OVModelForWhisper):
            self._quantize_whisper_model(quantization_config, calibration_dataset, **kwargs)
@@ -1077,6 +1072,14 @@ def _full_quantization(
            matmul=quantization_config.smooth_quant_alpha
        )
 
+    q_mode_map = {
+        "f8e4m3": nncf.QuantizationMode.FP8_E4M3,
+        "f8e5m2": nncf.QuantizationMode.FP8_E5M2,
+    }
+
+    if quantization_config.activation_format in q_mode_map:
+        kwargs.update({"mode": q_mode_map[quantization_config.activation_format]})
+
    quantized_model = nncf.quantize(
        model,
        calibration_dataset,
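
The mapping above forwards the FP8 formats to NNCF's `mode` argument. A hedged sketch of the call this path reduces to, assuming an `openvino.Model` (`model`) and an `nncf.Dataset` (`calibration_dataset`) are already prepared and with an illustrative `subset_size`:

```python
import nncf

# FP8 quantization via nncf.quantize()'s `mode` parameter, as wired up above.
quantized_model = nncf.quantize(
    model,                 # an openvino.Model
    calibration_dataset,   # an nncf.Dataset of calibration samples
    mode=nncf.QuantizationMode.FP8_E4M3,  # or nncf.QuantizationMode.FP8_E5M2
    subset_size=32,        # illustrative number of calibration samples
)
```

FP8 modes insert FakeConvert rather than FakeQuantize operations into the resulting graph, which is why the test utilities below start counting both node types.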

tests/openvino/test_exporters_cli.py (+22 −12)

@@ -118,10 +118,19 @@ class OVCLIExportTestCase(unittest.TestCase):
        (
            "automatic-speech-recognition",
            "whisper",
-            "--quant-mode int8 --dataset librispeech --num-samples 1 --smooth-quant-alpha 0.9 --trust-remote-code",
+            "int8",
+            "--dataset librispeech --num-samples 1 --smooth-quant-alpha 0.9 --trust-remote-code",
            (14, 22, 21) if is_transformers_version("<=", "4.36.0") else (14, 22, 25),
            (14, 21, 17) if is_transformers_version("<=", "4.36.0") else (14, 22, 18),
        ),
+        (
+            "text-generation",
+            "llama",
+            "f8e4m3",
+            "--dataset wikitext2 --num-samples 1 --smooth-quant-alpha 0.9 --trust-remote-code",
+            (13,),
+            (16,),
+        ),
    ]
 
    TEST_4BIT_CONFIGURATIONS = [
@@ -411,30 +420,31 @@ def test_exporters_cli_full_quantization(
        self,
        task: str,
        model_type: str,
+        quant_mode: str,
        option: str,
-        expected_num_fq_nodes_per_model: Tuple[int],
+        expected_num_f_nodes_per_model: Tuple[int],
        expected_num_weight_nodes_per_model: Tuple[int],
    ):
        with TemporaryDirectory() as tmpdir:
            subprocess.run(
-                f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} {option} {tmpdir}",
+                f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --quant-mode {quant_mode} {option} {tmpdir}",
                shell=True,
                check=True,
            )
            model = eval(_HEAD_TO_AUTOMODELS[task]).from_pretrained(tmpdir)
 
-            submodels = []
+            models = [model]
            if task == "automatic-speech-recognition":
-                submodels = [model.encoder, model.decoder]
+                models = [model.encoder, model.decoder]
                if model.decoder_with_past is not None:
-                    submodels.append(model.decoder_with_past)
+                    models.append(model.decoder_with_past)
                else:
-                    expected_num_fq_nodes_per_model = expected_num_fq_nodes_per_model[:-1]
-            self.assertEqual(len(expected_num_fq_nodes_per_model), len(submodels))
-            for i, model in enumerate(submodels):
-                actual_num_fq_nodes, actual_num_weight_nodes = get_num_quantized_nodes(model)
-                self.assertEqual(expected_num_fq_nodes_per_model[i], actual_num_fq_nodes)
-                self.assertEqual(expected_num_weight_nodes_per_model[i], actual_num_weight_nodes["int8"])
+                    expected_num_f_nodes_per_model = expected_num_f_nodes_per_model[:-1]
+            self.assertEqual(len(expected_num_f_nodes_per_model), len(models))
+            for i, model in enumerate(models):
+                actual_num_f_nodes, actual_num_weight_nodes = get_num_quantized_nodes(model)
+                self.assertEqual(expected_num_f_nodes_per_model[i], actual_num_f_nodes)
+                self.assertEqual(expected_num_weight_nodes_per_model[i], actual_num_weight_nodes[quant_mode])
 
    def test_exporters_cli_int4_with_local_model_and_default_config(self):
        with TemporaryDirectory() as tmpdir:
tests/openvino/utils_tests.py (+19 −19)

@@ -206,31 +206,31 @@
 
 
 def get_num_quantized_nodes(model):
-    num_fake_quantize = 0
-    num_weight_nodes = {
-        "int8": 0,
-        "int4": 0,
-        "f4e2m1": 0,
-        "f8e8m0": 0,
-        "nf4": 0,
+    num_fake_nodes = 0
+    types_map = {
+        "i8": "int8",
+        "u8": "int8",
+        "i4": "int4",
+        "u4": "int4",
+        "f4e2m1": "f4e2m1",
+        "f8e8m0": "f8e8m0",
+        "nf4": "nf4",
+        "f8e4m3": "f8e4m3",
+        "f8e5m2": "f8e5m2",
    }
+    num_weight_nodes = {n: 0 for n in types_map.values()}
    ov_model = model if isinstance(model, ov.Model) else model.model
    for elem in ov_model.get_ops():
        if "FakeQuantize" in elem.name:
-            num_fake_quantize += 1
+            num_fake_nodes += 1
+        if "FakeConvert" in elem.name:
+            num_fake_nodes += 1
        for i in range(elem.get_output_size()):
            type_name = elem.get_output_element_type(i).get_type_name()
-            if type_name in ["i8", "u8"]:
-                num_weight_nodes["int8"] += 1
-            if type_name in ["i4", "u4"]:
-                num_weight_nodes["int4"] += 1
-            if type_name == "f4e2m1":
-                num_weight_nodes["f4e2m1"] += 1
-            if type_name == "f8e8m0":
-                num_weight_nodes["f8e8m0"] += 1
-            if type_name == "nf4":
-                num_weight_nodes["nf4"] += 1
-    return num_fake_quantize, num_weight_nodes
+            if type_name in types_map:
+                name = types_map[type_name]
+                num_weight_nodes[name] += 1
+    return num_fake_nodes, num_weight_nodes
 
 
 @contextmanager
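
A short usage sketch of the updated helper (the model class is real, but the local path is illustrative and the import assumes the test layout above, i.e. running from `tests/openvino/`):

```python
from optimum.intel import OVModelForCausalLM

from utils_tests import get_num_quantized_nodes  # tests/openvino/utils_tests.py

# Count fake (de)quantization ops and weight dtypes in an exported model.
model = OVModelForCausalLM.from_pretrained("./llama-f8e4m3")
num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(model)
print(num_fake_nodes)              # FakeQuantize + FakeConvert nodes
print(num_weight_nodes["f8e4m3"])  # outputs typed as f8e4m3
```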
