Skip to content

Commit 0420051

Browse files
committed
Raise an error when OVQuantizer is invoked on an already compressed model
1 parent 2590794 commit 0420051

File tree

2 files changed

+59
-2
lines changed

2 files changed

+59
-2
lines changed

optimum/intel/openvino/quantization.py

+42-2
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,31 @@ def quantize(
315315
else:
316316
raise TypeError(f"Unsupported model type: {type(self.model)}")
317317

318+
def _check_model_state(self, sub_model_names: List[str] = None):
    """
    Raise if the wrapped model (or any of its sub-models) was already compressed.

    The check looks at the ``"nncf"`` entry of each OpenVINO model's runtime info and
    fails when it carries a ``"weight_compression"`` or ``"quantization"`` config.

    Args:
        sub_model_names (`List[str]`, *optional*):
            Attribute names on ``self.model`` whose ``.model`` should be inspected.
            Names that are missing on ``self.model`` or whose value is ``None`` are skipped.
            When ``None``, ``self.model.model`` itself is inspected.

    Raises:
        RuntimeError: If a checked model carries a weight compression or quantization config.
    """
    message_template = (
        "Couldn't apply optimization to the model because it was already compressed with config: {}. "
        "To avoid this issue, set load_in_8bit=False in the from_pretrained method when using the optimum-intel API, "
        "or explicitly specify the desired weight format using --weight_format fp16/fp32 for CLI."
    )

    def check_rt_info(ov_model):
        rt_info = ov_model.get_rt_info()
        if "nncf" not in rt_info:
            return
        nncf_info = rt_info["nncf"]
        # Weight compression is reported first, matching the original priority order.
        for config_key in ("weight_compression", "quantization"):
            applied_config = nncf_info.get(config_key, None)
            if applied_config is not None:
                raise RuntimeError(message_template.format(applied_config))

    if sub_model_names is None:
        check_rt_info(self.model.model)
        return

    for name in sub_model_names:
        # An attribute can exist but be unset (None); the quantization call sites filter
        # such sub-models out, so skip them here instead of failing on `None.model`.
        sub_model = getattr(self.model, name, None)
        if sub_model is None:
            continue
        check_rt_info(sub_model.model)
342+
318343
def _quantize_ovbasemodel(
319344
self,
320345
ov_config: OVConfig,
@@ -325,7 +350,7 @@ def _quantize_ovbasemodel(
325350
remove_unused_columns: bool = True,
326351
**kwargs,
327352
):
328-
from optimum.intel.openvino.modeling_seq2seq import _OVModelForWhisper
353+
from optimum.intel.openvino.modeling_seq2seq import _OVModelForWhisper, OVModelForSeq2SeqLM
329354
from optimum.intel.openvino.modeling_visual_language import OVModelForVisualCausalLM
330355

331356
if is_diffusers_available():
@@ -404,6 +429,7 @@ def _quantize_ovbasemodel(
404429
"text_encoder_2",
405430
"text_encoder_3",
406431
]
432+
self._check_model_state(sub_model_names)
407433
sub_models = filter(lambda x: x, (getattr(self.model, name) for name in sub_model_names))
408434
for sub_model in sub_models:
409435
_weight_only_quantization(sub_model.model, quantization_config_copy, **kwargs)
@@ -421,6 +447,7 @@ def _quantize_ovbasemodel(
421447
self.model.clear_requests()
422448
else:
423449
# The model may be for example OVModelForImageClassification, OVModelForAudioClassification, etc.
450+
self._check_model_state()
424451
self.model.model = _hybrid_quantization(
425452
self.model.model, quantization_config, calibration_dataset, **kwargs
426453
)
@@ -436,19 +463,31 @@ def _quantize_ovbasemodel(
436463
"transformer",
437464
"text_encoder_3",
438465
]
466+
self._check_model_state(sub_model_names)
439467
sub_models = filter(lambda x: x, (getattr(self.model, name) for name in sub_model_names))
440468
for sub_model in sub_models:
441469
_weight_only_quantization(sub_model.model, quantization_config, **kwargs)
442470
self.model.clear_requests()
443471
elif isinstance(self.model, OVModelForVisualCausalLM):
444472
language_model = self.model.language_model
445-
_weight_only_quantization(language_model.model, quantization_config, calibration_dataset, **kwargs)
446473
sub_model_names = ["vision_embeddings", "text_embeddings"] + self.model.additional_parts
474+
self._check_model_state(sub_model_names + ["language_model"])
475+
_weight_only_quantization(language_model.model, quantization_config, calibration_dataset, **kwargs)
447476
sub_models = [getattr(self.model, f"{name}_model") for name in sub_model_names]
448477
for sub_model in sub_models:
449478
_weight_only_quantization(sub_model, OVWeightQuantizationConfig(bits=8, sym=True), **kwargs)
450479
self.model.clear_requests()
480+
elif isinstance(self.model, OVModelForSeq2SeqLM):
481+
sub_model_names = ["encoder", "decoder"]
482+
if self.model.decoder_with_past is not None:
483+
sub_model_names.append("decoder_with_past")
484+
self._check_model_state(sub_model_names)
485+
sub_models = [getattr(self.model, name) for name in sub_model_names]
486+
for sub_model in sub_models:
487+
_weight_only_quantization(sub_model, quantization_config, **kwargs)
488+
self.model.clear_requests()
451489
else:
490+
self._check_model_state()
452491
_weight_only_quantization(self.model.model, quantization_config, calibration_dataset, **kwargs)
453492
self.model.request = None
454493
else:
@@ -460,6 +499,7 @@ def _quantize_ovbasemodel(
460499

461500
# Quantize model(s)
462501
if isinstance(self.model, _OVModelForWhisper):
502+
self._check_model_state(["encoder_model", "decoder_model", "decoder_with_past_model"])
463503
self._quantize_whisper_model(quantization_config, calibration_dataset, **kwargs)
464504
else:
465505
quantized_model = _full_quantization(

tests/openvino/test_quantization.py

+17
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,23 @@ def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type, trust
698698
_, num_weight_nodes = get_num_quantized_nodes(model)
699699
self.assertEqual(expected_ov_int8[i], num_weight_nodes["int8"])
700700

701+
@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION)
def test_raise_error_WC_over_WC(self, model_cls, model_type, trust_remote_code):
    """Applying weight compression to a model loaded with `load_in_8bit=True` must raise."""
    model = model_cls.from_pretrained(
        MODEL_NAMES[model_type],
        export=True,
        load_in_8bit=True,
        trust_remote_code=trust_remote_code,
    )
    quantization_config = OVWeightQuantizationConfig(bits=4, sym=True)
    quantizer = OVQuantizer(model)

    # OpenCLIP models are expected to fail with a TypeError rather than the
    # "already compressed" RuntimeError (presumably the unsupported-model-type path; verify).
    if isinstance(model, OVModelOpenCLIPForZeroShotImageClassification):
        expected_error = TypeError
    else:
        expected_error = RuntimeError

    with pytest.raises(expected_error):
        quantizer.quantize(ov_config=OVConfig(quantization_config=quantization_config))
717+
701718
@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_HYBRID_QUANTIZATION)
702719
def test_ovmodel_hybrid_quantization(self, model_cls, model_type, expected_num_fake_quantize, expected_ov_int8):
703720
model_id = MODEL_NAMES[model_type]

0 commit comments

Comments
 (0)