@@ -315,31 +315,6 @@ def quantize(
         else:
             raise TypeError(f"Unsupported model type: {type(self.model)}")
 
-    def _check_model_state(self, sub_model_names: List[str] = None):
-        message_template = (
-            "Couldn't apply optimization to the model because it was already compressed with config: {}. "
-            "To avoid this issue, set load_in_8bit=False in the from_pretrained method when using the optimum-intel API, "
-            "or explicitly specify the desired weight format using --weight_format fp16/fp32 for CLI."
-        )
-
-        def check_rt_info(ov_model):
-            rt_info = ov_model.get_rt_info()
-            if "nncf" in rt_info:
-                model_weight_compression_config = rt_info["nncf"].get("weight_compression", None)
-                model_quantization_config = rt_info["nncf"].get("quantization", None)
-                if model_weight_compression_config is not None:
-                    raise RuntimeError(message_template.format(model_weight_compression_config))
-                elif model_quantization_config is not None:
-                    raise RuntimeError(message_template.format(model_quantization_config))
-
-        if sub_model_names is None:
-            check_rt_info(self.model.model)
-        else:
-            for name in sub_model_names:
-                if hasattr(self.model, name):
-                    ov_model = getattr(self.model, name).model
-                    check_rt_info(ov_model)
-
     def _quantize_ovbasemodel(
         self,
         ov_config: OVConfig,
@@ -350,7 +325,7 @@ def _quantize_ovbasemodel(
         remove_unused_columns: bool = True,
         **kwargs,
     ):
-        from optimum.intel.openvino.modeling_seq2seq import _OVModelForWhisper, OVModelForSeq2SeqLM
+        from optimum.intel.openvino.modeling_seq2seq import _OVModelForWhisper
         from optimum.intel.openvino.modeling_visual_language import OVModelForVisualCausalLM
 
         if is_diffusers_available():
@@ -429,7 +404,6 @@ def _quantize_ovbasemodel(
                         "text_encoder_2",
                         "text_encoder_3",
                     ]
-                    self._check_model_state(sub_model_names)
                     sub_models = filter(lambda x: x, (getattr(self.model, name) for name in sub_model_names))
                     for sub_model in sub_models:
                         _weight_only_quantization(sub_model.model, quantization_config_copy, **kwargs)
@@ -447,7 +421,6 @@ def _quantize_ovbasemodel(
                     self.model.clear_requests()
                 else:
                     # The model may be for example OVModelForImageClassification, OVModelForAudioClassification, etc.
-                    self._check_model_state()
                     self.model.model = _hybrid_quantization(
                         self.model.model, quantization_config, calibration_dataset, **kwargs
                     )
@@ -463,31 +436,19 @@ def _quantize_ovbasemodel(
                         "transformer",
                         "text_encoder_3",
                     ]
-                    self._check_model_state(sub_model_names)
                     sub_models = filter(lambda x: x, (getattr(self.model, name) for name in sub_model_names))
                     for sub_model in sub_models:
                         _weight_only_quantization(sub_model.model, quantization_config, **kwargs)
                     self.model.clear_requests()
                 elif isinstance(self.model, OVModelForVisualCausalLM):
                     language_model = self.model.language_model
-                    sub_model_names = ["vision_embeddings", "text_embeddings"] + self.model.additional_parts
-                    self._check_model_state(sub_model_names + ["language_model"])
                     _weight_only_quantization(language_model.model, quantization_config, calibration_dataset, **kwargs)
+                    sub_model_names = ["vision_embeddings", "text_embeddings"] + self.model.additional_parts
                     sub_models = [getattr(self.model, f"{name}_model") for name in sub_model_names]
                     for sub_model in sub_models:
                         _weight_only_quantization(sub_model, OVWeightQuantizationConfig(bits=8, sym=True), **kwargs)
                     self.model.clear_requests()
-                elif isinstance(self.model, OVModelForSeq2SeqLM):
-                    sub_model_names = ["encoder", "decoder"]
-                    if self.model.decoder_with_past is not None:
-                        sub_model_names.append("decoder_with_past")
-                    self._check_model_state(sub_model_names)
-                    sub_models = [getattr(self.model, name) for name in sub_model_names]
-                    for sub_model in sub_models:
-                        _weight_only_quantization(sub_model, quantization_config, **kwargs)
-                    self.model.clear_requests()
                 else:
-                    self._check_model_state()
                     _weight_only_quantization(self.model.model, quantization_config, calibration_dataset, **kwargs)
                     self.model.request = None
         else:
@@ -499,7 +460,6 @@ def _quantize_ovbasemodel(
 
         # Quantize model(s)
         if isinstance(self.model, _OVModelForWhisper):
-            self._check_model_state(["encoder_model", "decoder_model", "decoder_with_past_model"])
             self._quantize_whisper_model(quantization_config, calibration_dataset, **kwargs)
         else:
             quantized_model = _full_quantization(
@@ -1050,6 +1010,7 @@ def _weight_only_quantization(
     calibration_dataset: Optional[Union[nncf.Dataset, Iterable]] = None,
     **kwargs,
 ) -> openvino.runtime.Model:
+    _verify_not_optimized(model)
     config = quantization_config
     if isinstance(config, dict):
         config = OVWeightQuantizationConfig.from_dict(quantization_config)
@@ -1106,6 +1067,7 @@ def _full_quantization(
     calibration_dataset: nncf.Dataset,
     **kwargs,
 ):
+    _verify_not_optimized(model)
     advanced_parameters_kwargs = {}
     if quantization_config.smooth_quant_alpha is not None:
         advanced_parameters_kwargs["smooth_quant_alphas"] = AdvancedSmoothQuantParameters(
@@ -1227,3 +1189,20 @@ def _hybrid_quantization(
         **kwargs,
     )
     return quantized_model
+
+
+def _verify_not_optimized(ov_model):
+    message_template = (
+        "Cannot apply optimization to the model because it was already optimized with the following config: {}. "
+        "To avoid this issue, set load_in_8bit=False or do not use a quantization_config at export in .from_pretrained(), "
+        "or explicitly specify the weight format with --weight_format fp16/fp32 when using the CLI."
+    )
+
+    rt_info = ov_model.get_rt_info()
+    if "nncf" in rt_info:
+        model_weight_compression_config = rt_info["nncf"].get("weight_compression", None)
+        model_quantization_config = rt_info["nncf"].get("quantization", None)
+        if model_weight_compression_config is not None:
+            raise RuntimeError(message_template.format(model_weight_compression_config))
+        elif model_quantization_config is not None:
+            raise RuntimeError(message_template.format(model_quantization_config))