
Commit 38b6e54

Authored Jan 20, 2025
Add FP8 quantization test (#1114)
* Add LLM FP8 quantization test; rename test variables; add dataset name check
* Create a variable for supported language datasets
* ruff
1 parent 2590794 commit 38b6e54

File tree

4 files changed: +90 −48 lines

optimum/intel/openvino/configuration.py (+18 −5)

@@ -26,7 +26,12 @@
 from optimum.configuration_utils import BaseConfig
 
 from ..utils.import_utils import is_nncf_available
-from .utils import PREDEFINED_SD_DATASETS, PREDEFINED_VISUAL_LM_DATASETS
+from .utils import (
+    LANGUAGE_DATASETS,
+    PREDEFINED_SD_DATASETS,
+    PREDEFINED_SPEECH_TO_TEXT_DATASETS,
+    PREDEFINED_VISUAL_LM_DATASETS,
+)
 
 
 if is_nncf_available():
@@ -467,13 +472,12 @@ def post_init(self):
                 f"If you wish to provide a custom dataset, please use the `OVQuantizer` instead."
             )
         if self.dataset is not None and isinstance(self.dataset, str):
-            lm_datasets = ["wikitext2", "c4", "c4-new", "auto"]
             visual_lm_datasets = list(PREDEFINED_VISUAL_LM_DATASETS.keys())
             stable_diffusion_datasets = list(PREDEFINED_SD_DATASETS.keys())
-            if self.dataset not in lm_datasets + visual_lm_datasets + stable_diffusion_datasets:
+            if self.dataset not in LANGUAGE_DATASETS + visual_lm_datasets + stable_diffusion_datasets:
                 raise ValueError(
                     f"""You have entered a string value for dataset. You can only choose between
-                    {lm_datasets} for LLMs, {visual_lm_datasets} for visual LLMs
+                    {LANGUAGE_DATASETS} for LLMs, {visual_lm_datasets} for visual LLMs
                     or {stable_diffusion_datasets} for diffusion models, but we found {self.dataset}"""
                 )
 
@@ -617,7 +621,8 @@ def __init__(
            overflow_fix (`str`, default to "disable"):
                Parameter for controlling overflow fix setting.
            dataset (`str`, *optional*):
-                The dataset used for quantization. For text-to-speech model quantization the allowed value is 'librispeech'.
+                The dataset used for quantization. For language models the allowed values are
+                ['auto', 'wikitext2','c4','c4-new']. For text-to-speech model quantization the allowed value is 'librispeech'.
            tokenizer (`str`, *optional*):
                The tokenizer used to process the dataset. You can pass either:
                    - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
@@ -673,6 +678,14 @@ def post_init(self):
         """
         super().post_init()
 
+        if self.dataset is not None:
+            speech_to_text_datasets = list(PREDEFINED_SPEECH_TO_TEXT_DATASETS.keys())
+            if self.dataset not in LANGUAGE_DATASETS + speech_to_text_datasets:
+                raise ValueError(
+                    f"""You can only choose between the following datasets: {LANGUAGE_DATASETS} for LLMs or
+                    {speech_to_text_datasets} for speech-to-text models, but we found {self.dataset}."""
+                )
+
         if self.bits != 8:
             raise ValueError(f"Only support 8-bit for static quantization but found {self.bits}")
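
Net effect of these configuration changes: both quantization configs now validate string dataset names against the shared LANGUAGE_DATASETS constant in post_init. A minimal sketch of the resulting behavior (assuming a build of optimum-intel that includes this commit, and that post_init runs on construction as it does for these configs):

    from optimum.intel import OVQuantizationConfig

    # A supported language dataset name passes validation.
    config = OVQuantizationConfig(dataset="wikitext2", num_samples=1)

    # An unrecognized name now fails fast with a ValueError that lists the
    # allowed language and speech-to-text datasets.
    try:
        OVQuantizationConfig(dataset="imagenet")
    except ValueError as err:
        print(err)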

optimum/intel/openvino/utils.py (+2 −0)

@@ -136,6 +136,8 @@
 }
 
 
+LANGUAGE_DATASETS = ["wikitext2", "c4", "c4-new", "auto"]
+
 PREDEFINED_SD_DATASETS = {
     "conceptual_captions": {"split": "train", "inputs": {"prompt": "caption"}},
     "laion/220k-GPT4Vision-captions-from-LIVIS": {"split": "train", "inputs": {"prompt": "caption"}},

tests/openvino/test_exporters_cli.py (+13 −11)

@@ -365,19 +365,21 @@ def test_exporters_cli_int8(self, task: str, model_type: str):
             self.assertEqual(expected_int8[i], num_weight_nodes["int8"])
 
     @parameterized.expand(SUPPORTED_SD_HYBRID_ARCHITECTURES)
-    def test_exporters_cli_hybrid_quantization(self, model_type: str, exp_num_fq: int, exp_num_int8: int):
+    def test_exporters_cli_hybrid_quantization(
+        self, model_type: str, expected_fake_nodes: int, expected_int8_nodes: int
+    ):
         with TemporaryDirectory() as tmpdir:
             subprocess.run(
                 f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --dataset laion/filtered-wit --weight-format int8 {tmpdir}",
                 shell=True,
                 check=True,
             )
             model = eval(_HEAD_TO_AUTOMODELS[model_type.replace("-refiner", "")]).from_pretrained(tmpdir)
-            num_fq, num_weight_nodes = get_num_quantized_nodes(
+            num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(
                 model.unet if model.unet is not None else model.transformer
             )
-            self.assertEqual(exp_num_int8, num_weight_nodes["int8"])
-            self.assertEqual(exp_num_fq, num_fq)
+            self.assertEqual(expected_int8_nodes, num_weight_nodes["int8"])
+            self.assertEqual(expected_fake_nodes, num_fake_nodes)
 
     @parameterized.expand(TEST_4BIT_CONFIGURATIONS)
     def test_exporters_cli_4bit(
@@ -422,8 +424,8 @@ def test_exporters_cli_full_quantization(
         model_type: str,
         quant_mode: str,
         option: str,
-        expected_num_f_nodes_per_model: Tuple[int],
-        expected_num_weight_nodes_per_model: Tuple[int],
+        expected_fake_nodes: Tuple[int],
+        expected_low_precision_nodes: Tuple[int],
     ):
         with TemporaryDirectory() as tmpdir:
             subprocess.run(
@@ -439,12 +441,12 @@ def test_exporters_cli_full_quantization(
             if model.decoder_with_past is not None:
                 models.append(model.decoder_with_past)
             else:
-                expected_num_f_nodes_per_model = expected_num_f_nodes_per_model[:-1]
-            self.assertEqual(len(expected_num_f_nodes_per_model), len(models))
+                expected_fake_nodes = expected_fake_nodes[:-1]
+            self.assertEqual(len(expected_fake_nodes), len(models))
             for i, model in enumerate(models):
-                actual_num_f_nodes, actual_num_weight_nodes = get_num_quantized_nodes(model)
-                self.assertEqual(expected_num_f_nodes_per_model[i], actual_num_f_nodes)
-                self.assertEqual(expected_num_weight_nodes_per_model[i], actual_num_weight_nodes[quant_mode])
+                num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(model)
+                self.assertEqual(expected_fake_nodes[i], num_fake_nodes)
+                self.assertEqual(expected_low_precision_nodes[i], num_weight_nodes[quant_mode])
 
     def test_exporters_cli_int4_with_local_model_and_default_config(self):
         with TemporaryDirectory() as tmpdir:
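
These renames also touch test_exporters_cli_full_quantization, which drives the same flow through the exporter CLI. A sketch of the kind of command that test issues (the model id is a placeholder, and the --quant-mode/--dataset/--num-samples flags are assumptions inferred from the quant_mode parameter the test passes through; they are not shown verbatim in this diff):

    import subprocess
    from tempfile import TemporaryDirectory

    with TemporaryDirectory() as tmpdir:
        # Full (weights + activations) FP8 quantization via the exporter CLI.
        subprocess.run(
            f"optimum-cli export openvino --model <tiny-llama-model-id> "
            f"--quant-mode f8e4m3 --dataset wikitext2 --num-samples 1 {tmpdir}",
            shell=True,
            check=True,
        )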

tests/openvino/test_quantization.py (+57 −32)

@@ -114,10 +114,23 @@ class OVQuantizerTest(unittest.TestCase):
            (14, 22, 21) if is_transformers_version("<=", "4.42.4") else (14, 22, 25),
            (14, 21, 17) if is_transformers_version("<=", "4.42.4") else (14, 22, 18),
        ),
+        (
+            OVModelForCausalLM,
+            "llama",
+            OVQuantizationConfig(
+                dataset="wikitext2",
+                num_samples=1,
+                weight_only=False,
+                weight_format="f8e4m3",
+                activation_format="f8e4m3",
+            ),
+            (13,),
+            (16,),
+        ),
    ]
 
    @parameterized.expand(SUPPORTED_ARCHITECTURES_TORCH_MODEL)
-    def test_automodel_static_quantization(self, model_cls, model_name, expected_fake_quantize, expected_int8):
+    def test_automodel_static_quantization(self, model_cls, model_name, expected_fake_nodes, expected_int8_nodes):
        model_id = MODEL_NAMES[model_name]
        task = model_cls.export_feature
        dataset_name, dataset_config_name, column_name = _TASK_TO_DATASET[task]
@@ -149,9 +162,9 @@ def preprocess_function(examples, tokenizer):
                ov_config=ov_config,
            )
            model = model_cls.from_pretrained(tmp_dir, file_name=file_name)
-            num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(model)
-            self.assertEqual(expected_fake_quantize, num_fake_quantize)
-            self.assertEqual(expected_int8, num_weight_nodes["int8"])
+            num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(model)
+            self.assertEqual(expected_fake_nodes, num_fake_nodes)
+            self.assertEqual(expected_int8_nodes, num_weight_nodes["int8"])
 
            tokens = tokenizer("This is a sample input", return_tensors="pt")
            outputs = model(**tokens)
@@ -162,7 +175,7 @@ def preprocess_function(examples, tokenizer):
        self.assertEqual(ov_config.quantization_config.to_dict(), loaded_config.quantization_config.to_dict())
 
    @parameterized.expand(SUPPORTED_ARCHITECTURES_OV_MODEL)
-    def test_ovmodel_static_quantization(self, model_cls, model_name, expected_fake_quantize, expected_int8):
+    def test_ovmodel_static_quantization(self, model_cls, model_name, expected_fake_nodes, expected_int8_nodes):
        model_id = MODEL_NAMES[model_name]
        task = model_cls.export_feature
        dataset_name, dataset_config_name, column_name = _TASK_TO_DATASET[task]
@@ -190,9 +203,9 @@ def preprocess_function(examples, tokenizer):
 
            model = model_cls.from_pretrained(tmp_dir)
 
-            num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(model)
-            self.assertEqual(expected_fake_quantize, num_fake_quantize)
-            self.assertEqual(expected_int8, num_weight_nodes["int8"])
+            num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(model)
+            self.assertEqual(expected_fake_nodes, num_fake_nodes)
+            self.assertEqual(expected_int8_nodes, num_weight_nodes["int8"])
 
            tokens = tokenizer("This is a sample input", return_tensors="pt")
            outputs = model(**tokens)
@@ -204,9 +217,10 @@ def preprocess_function(examples, tokenizer):
 
    @parameterized.expand(SUPPORTED_ARCHITECTURES_OV_MODEL_WITH_AUTO_DATASET)
    def test_ov_model_static_quantization_with_auto_dataset(
-        self, model_cls, model_name, quantization_config, expected_fake_quantize, expected_int8
+        self, model_cls, model_name, quantization_config, expected_fake_nodes, expected_low_precision_nodes
    ):
        model_id = MODEL_NAMES[model_name]
+        quant_mode = quantization_config.activation_format
 
        with TemporaryDirectory() as tmp_dir:
            ov_model = model_cls.from_pretrained(model_id, quantization_config=quantization_config)
@@ -217,17 +231,28 @@ def test_ov_model_static_quantization_with_auto_dataset(
 
                if ov_model.decoder_with_past is not None:
                    models.append(ov_model.decoder_with_past.model)
-                for model, expected_fq, expected_i8 in zip(
+                for model, expected_fake_nodes, expected_lp_nodes in zip(
                    models,
-                    expected_fake_quantize,
-                    expected_int8,
+                    expected_fake_nodes,
+                    expected_low_precision_nodes,
                ):
-                    num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(model)
-                    self.assertEqual(expected_fq, num_fake_quantize)
-                    self.assertEqual(expected_i8, num_weight_nodes["int8"])
+                    num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(model)
+                    self.assertEqual(expected_fake_nodes, num_fake_nodes)
+                    self.assertEqual(expected_lp_nodes, num_weight_nodes[quant_mode])
 
                input_features = torch.randn((1, 128, 3000), dtype=torch.float32)
                ov_model.generate(input_features)
+            elif model_cls == OVModelForCausalLM:
+                num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(ov_model.model)
+                self.assertEqual(expected_fake_nodes[0], num_fake_nodes)
+                self.assertEqual(expected_low_precision_nodes[0], num_weight_nodes[quant_mode])
+
+                tokenizer = AutoTokenizer.from_pretrained(model_id)
+                if tokenizer.pad_token is None:
+                    tokenizer.pad_token = tokenizer.eos_token
+                tokens = tokenizer("This is a sample input", return_tensors="pt")
+                outputs = ov_model(**tokens)
+                self.assertTrue("logits" in outputs)
            else:
                raise Exception("Unexpected model class.")
 
@@ -608,7 +633,7 @@ def test_ovmodel_8bit_weight_compression(self, model_cls, model_name, expected_p
        self.assertEqual(OVWeightQuantizationConfig().to_dict(), loaded_config.quantization_config.to_dict())
 
    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS)
-    def test_ovmodel_4bit_weight_compression(self, model_cls, model_name, expected_int8, expected_int4):
+    def test_ovmodel_4bit_weight_compression(self, model_cls, model_name, expected_int8_nodes, expected_int4_nodes):
        task = model_cls.export_feature
        model_id = MODEL_NAMES[model_name]
        with TemporaryDirectory() as tmp_dir:
@@ -623,8 +648,8 @@ def test_ovmodel_4bit_weight_compression(self, model_cls, model_name, expected_i
            model = model_cls.from_pretrained(tmp_dir)
 
            _, num_weight_nodes = get_num_quantized_nodes(model)
-            self.assertEqual(expected_int8, num_weight_nodes["int8"])
-            self.assertEqual(expected_int4, num_weight_nodes["int4"])
+            self.assertEqual(expected_int8_nodes, num_weight_nodes["int8"])
+            self.assertEqual(expected_int4_nodes, num_weight_nodes["int4"])
 
            tokens = tokenizer("This is a sample input", return_tensors="pt")
            outputs = model(**tokens)
@@ -699,17 +724,17 @@ def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type, trust
            self.assertEqual(expected_ov_int8[i], num_weight_nodes["int8"])
 
    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_HYBRID_QUANTIZATION)
-    def test_ovmodel_hybrid_quantization(self, model_cls, model_type, expected_num_fake_quantize, expected_ov_int8):
+    def test_ovmodel_hybrid_quantization(self, model_cls, model_type, expected_fake_nodes, expected_int8_nodes):
        model_id = MODEL_NAMES[model_type]
        quantization_config = OVWeightQuantizationConfig(bits=8, dataset="conceptual_captions", num_samples=2)
        with TemporaryDirectory() as tmp_dir:
            model = model_cls.from_pretrained(model_id, export=True, quantization_config=quantization_config)
 
-            num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(
+            num_fake, num_weight_nodes = get_num_quantized_nodes(
                model.unet if model.unet is not None else model.transformer
            )
-            self.assertEqual(expected_num_fake_quantize, num_fake_quantize)
-            self.assertEqual(expected_ov_int8, num_weight_nodes["int8"])
+            self.assertEqual(expected_fake_nodes, num_fake)
+            self.assertEqual(expected_int8_nodes, num_weight_nodes["int8"])
            self.assertEqual(0, num_weight_nodes["int4"])
 
            model.save_pretrained(tmp_dir)
@@ -721,16 +746,16 @@ def test_stable_diffusion_with_weight_compression(self):
 
        quantizer.quantize(ov_config=OVConfig(quantization_config=quantization_config))
 
-        num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(
+        num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(
            int8_pipe.unet if int8_pipe.unet is not None else int8_pipe.transformer
        )
-        self.assertEqual(0, num_fake_quantize)
+        self.assertEqual(0, num_fake_nodes)
        self.assertEqual(242, num_weight_nodes["int8"])
        self.assertEqual(0, num_weight_nodes["int4"])
 
    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_HYBRID_QUANTIZATION[-1:])
    def test_ovmodel_hybrid_quantization_with_custom_dataset(
-        self, model_cls, model_type, expected_num_fake_quantize, expected_ov_int8
+        self, model_cls, model_type, expected_fake_nodes, expected_int8_nodes
    ):
        model_id = MODEL_NAMES[model_type]
        dataset = [
@@ -742,11 +767,11 @@ def test_ovmodel_hybrid_quantization_with_custom_dataset(
        self.assertEqual(quantization_config.quant_method, OVQuantizationMethod.HYBRID)
 
        quantizer.quantize(ov_config=OVConfig(quantization_config=quantization_config), calibration_dataset=dataset)
-        num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(
+        num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(
            model.unet if model.unet is not None else model.transformer
        )
-        self.assertEqual(expected_num_fake_quantize, num_fake_quantize)
-        self.assertEqual(expected_ov_int8, num_weight_nodes["int8"])
+        self.assertEqual(expected_fake_nodes, num_fake_nodes)
+        self.assertEqual(expected_int8_nodes, num_weight_nodes["int8"])
        self.assertEqual(0, num_weight_nodes["int4"])
 
    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_AUTOCOMPRESSED_MATMULS)
@@ -1050,7 +1075,7 @@ class OVTrainerTest(unittest.TestCase):
    @unittest.skipIf(
        is_transformers_version(">=", "4.46"), reason="OVTrainer is not compatible with transformers>=v4.46"
    )
-    def test_aware_training_quantization(self, model_name, expected_fake_quantize, expected_int8):
+    def test_aware_training_quantization(self, model_name, expected_fake_nodes, expected_int8_nodes):
        model_id = MODEL_NAMES[model_name]
        model = AutoModelForSequenceClassification.from_pretrained(model_id, attn_implementation="eager")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -1084,9 +1109,9 @@ def compute_metrics(p):
            trainer.save_model()
 
            model = OVModelForSequenceClassification.from_pretrained(tmp_dir)
-            num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(model)
-            self.assertEqual(expected_fake_quantize, num_fake_quantize)
-            self.assertEqual(expected_int8, num_weight_nodes["int8"])
+            num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(model)
+            self.assertEqual(expected_fake_nodes, num_fake_nodes)
+            self.assertEqual(expected_int8_nodes, num_weight_nodes["int8"])
 
            tokens = tokenizer("This is a sample input", return_tensors="pt")
            outputs = model(**tokens)
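
The new SUPPORTED_ARCHITECTURES_OV_MODEL_WITH_AUTO_DATASET entry is what gives the commit its name: full FP8 (f8e4m3) quantization of a causal LM calibrated on wikitext2. Outside the test harness, the equivalent flow looks roughly like this (the model id and output directory are placeholders; assumes optimum-intel with NNCF installed):

    from optimum.intel import OVModelForCausalLM, OVQuantizationConfig

    config = OVQuantizationConfig(
        dataset="wikitext2",         # calibration data, validated by the new check
        num_samples=1,
        weight_only=False,           # quantize activations as well as weights
        weight_format="f8e4m3",      # FP8 weights
        activation_format="f8e4m3",  # FP8 activations
    )
    model = OVModelForCausalLM.from_pretrained("<llama-model-id>", quantization_config=config)
    model.save_pretrained("<output-dir>")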
