@@ -315,6 +315,31 @@ def quantize(
315
315
else :
316
316
raise TypeError (f"Unsupported model type: { type (self .model )} " )
317
317
318
def _check_model_state(self, sub_model_names: List[str] = None):
    """
    Refuse to optimize a model that already carries NNCF compression metadata.

    Each inspected OpenVINO model's runtime info is looked up for an ``"nncf"`` entry. If that
    entry holds a ``"weight_compression"`` or a ``"quantization"`` config, the model is
    considered already compressed and the optimization is aborted.

    Args:
        sub_model_names (`List[str]`, *optional*):
            Names of attributes of `self.model` whose `.model` should be inspected. When `None`,
            only `self.model.model` is inspected. Names that are not attributes of `self.model`,
            or whose attribute is `None`, are skipped.

    Raises:
        RuntimeError: If an inspected model has a weight compression or quantization config
            recorded under the ``"nncf"`` runtime info entry. The message includes that config.
    """
    message_template = (
        "Couldn't apply optimization to the model because it was already compressed with config: {}. "
        "To avoid this issue, set load_in_8bit=False in the from_pretrained method when using the optimum-intel API, "
        "or explicitly specify the desired weight format using --weight_format fp16/fp32 for CLI."
    )

    def check_rt_info(ov_model):
        rt_info = ov_model.get_rt_info()
        if "nncf" not in rt_info:
            return
        nncf_info = rt_info["nncf"]
        # Order matters when both entries exist: the weight compression config is reported first.
        for config_key in ("weight_compression", "quantization"):
            compression_config = nncf_info.get(config_key, None)
            if compression_config is not None:
                raise RuntimeError(message_template.format(compression_config))

    if sub_model_names is None:
        check_rt_info(self.model.model)
        return

    for name in sub_model_names:
        # An optional component can exist as an attribute and still be None (callers filter out
        # such sub-models before quantizing them); dereferencing `.model` on it would raise
        # AttributeError, so it is skipped just like a missing attribute.
        sub_model = getattr(self.model, name, None)
        if sub_model is None:
            continue
        check_rt_info(sub_model.model)
342
+
318
343
def _quantize_ovbasemodel (
319
344
self ,
320
345
ov_config : OVConfig ,
@@ -325,7 +350,7 @@ def _quantize_ovbasemodel(
325
350
remove_unused_columns : bool = True ,
326
351
** kwargs ,
327
352
):
328
- from optimum .intel .openvino .modeling_seq2seq import _OVModelForWhisper
353
+ from optimum .intel .openvino .modeling_seq2seq import _OVModelForWhisper , OVModelForSeq2SeqLM
329
354
from optimum .intel .openvino .modeling_visual_language import OVModelForVisualCausalLM
330
355
331
356
if is_diffusers_available ():
@@ -404,6 +429,7 @@ def _quantize_ovbasemodel(
404
429
"text_encoder_2" ,
405
430
"text_encoder_3" ,
406
431
]
432
+ self ._check_model_state (sub_model_names )
407
433
sub_models = filter (lambda x : x , (getattr (self .model , name ) for name in sub_model_names ))
408
434
for sub_model in sub_models :
409
435
_weight_only_quantization (sub_model .model , quantization_config_copy , ** kwargs )
@@ -421,6 +447,7 @@ def _quantize_ovbasemodel(
421
447
self .model .clear_requests ()
422
448
else :
423
449
# The model may be for example OVModelForImageClassification, OVModelForAudioClassification, etc.
450
+ self ._check_model_state ()
424
451
self .model .model = _hybrid_quantization (
425
452
self .model .model , quantization_config , calibration_dataset , ** kwargs
426
453
)
@@ -436,19 +463,31 @@ def _quantize_ovbasemodel(
436
463
"transformer" ,
437
464
"text_encoder_3" ,
438
465
]
466
+ self ._check_model_state (sub_model_names )
439
467
sub_models = filter (lambda x : x , (getattr (self .model , name ) for name in sub_model_names ))
440
468
for sub_model in sub_models :
441
469
_weight_only_quantization (sub_model .model , quantization_config , ** kwargs )
442
470
self .model .clear_requests ()
443
471
elif isinstance (self .model , OVModelForVisualCausalLM ):
444
472
language_model = self .model .language_model
445
- _weight_only_quantization (language_model .model , quantization_config , calibration_dataset , ** kwargs )
446
473
sub_model_names = ["vision_embeddings" , "text_embeddings" ] + self .model .additional_parts
474
+ self ._check_model_state (sub_model_names + ["language_model" ])
475
+ _weight_only_quantization (language_model .model , quantization_config , calibration_dataset , ** kwargs )
447
476
sub_models = [getattr (self .model , f"{ name } _model" ) for name in sub_model_names ]
448
477
for sub_model in sub_models :
449
478
_weight_only_quantization (sub_model , OVWeightQuantizationConfig (bits = 8 , sym = True ), ** kwargs )
450
479
self .model .clear_requests ()
480
+ elif isinstance (self .model , OVModelForSeq2SeqLM ):
481
+ sub_model_names = ["encoder" , "decoder" ]
482
+ if self .model .decoder_with_past is not None :
483
+ sub_model_names .append ("decoder_with_past" )
484
+ self ._check_model_state (sub_model_names )
485
+ sub_models = [getattr (self .model , name ) for name in sub_model_names ]
486
+ for sub_model in sub_models :
487
+ _weight_only_quantization (sub_model , quantization_config , ** kwargs )
488
+ self .model .clear_requests ()
451
489
else :
490
+ self ._check_model_state ()
452
491
_weight_only_quantization (self .model .model , quantization_config , calibration_dataset , ** kwargs )
453
492
self .model .request = None
454
493
else :
@@ -460,6 +499,7 @@ def _quantize_ovbasemodel(
460
499
461
500
# Quantize model(s)
462
501
if isinstance (self .model , _OVModelForWhisper ):
502
+ self ._check_model_state (["encoder_model" , "decoder_model" , "decoder_with_past_model" ])
463
503
self ._quantize_whisper_model (quantization_config , calibration_dataset , ** kwargs )
464
504
else :
465
505
quantized_model = _full_quantization (
0 commit comments