@@ -828,12 +828,14 @@ def _prepare_speech_to_text_calibration_data(self, config: OVQuantizationConfigB
828
828
decoder_model .request , decoder_calibration_data , apply_caching = True
829
829
)
830
830
831
- decoder_w_p_calibration_data = []
832
- decoder_w_p_model = self .model .decoder_with_past
833
- decoder_w_p_model ._compile ()
834
- decoder_w_p_model .request = InferRequestWrapper (
835
- decoder_w_p_model .request , decoder_w_p_calibration_data , apply_caching = True
836
- )
831
+ decoder_w_p_model = None
832
+ if self .model .decoder_with_past_model is not None :
833
+ decoder_w_p_calibration_data = []
834
+ decoder_w_p_model = self .model .decoder_with_past
835
+ decoder_w_p_model ._compile ()
836
+ decoder_w_p_model .request = InferRequestWrapper (
837
+ decoder_w_p_model .request , decoder_w_p_calibration_data , apply_caching = True
838
+ )
837
839
838
840
dataset_metadata = PREDEFINED_SPEECH_TO_TEXT_DATASETS [config .dataset ]
839
841
@@ -867,13 +869,13 @@ def _prepare_speech_to_text_calibration_data(self, config: OVQuantizationConfigB
867
869
finally :
868
870
encoder_model .request = encoder_model .request .request
869
871
decoder_model .request = decoder_model .request .request
870
- decoder_w_p_model .request = decoder_w_p_model .request .request
872
+ if decoder_w_p_model is not None :
873
+ decoder_w_p_model .request = decoder_w_p_model .request .request
871
874
872
- return (
873
- nncf .Dataset (encoder_calibration_data ),
874
- nncf .Dataset (decoder_calibration_data ),
875
- nncf .Dataset (decoder_w_p_calibration_data ),
876
- )
875
+ datasets = [nncf .Dataset (encoder_calibration_data ), nncf .Dataset (decoder_calibration_data ),]
876
+ if decoder_w_p_model is not None :
877
+ datasets .append (nncf .Dataset (decoder_w_p_calibration_data ))
878
+ return datasets
877
879
878
880
def _prepare_text_generation_calibration_data (
879
881
self , quantization_config : OVQuantizationConfigBase , calibration_dataloader : OVDataLoader
@@ -986,15 +988,16 @@ def _quantize_whisper_model(self, quantization_config, calibration_dataset, **kw
986
988
self .model .decoder .model = quantized_decoder_model
987
989
self .model .decoder .request = None
988
990
989
- # Quantize decoder with past model
990
- config = copy .deepcopy (quantization_config )
991
- config .num_samples = calibration_dataset [2 ].get_length ()
992
- quantized_decoder_w_p_model = _full_quantization (
993
- self .model .decoder_with_past_model , config , calibration_dataset [2 ], ** kwargs
994
- )
995
- self .model .decoder_with_past_model = quantized_decoder_w_p_model
996
- self .model .decoder_with_past .model = quantized_decoder_w_p_model
997
- self .model .decoder_with_past .request = None
991
+ if self .model .decoder_with_past_model is not None :
992
+ # Quantize decoder with past model
993
+ config = copy .deepcopy (quantization_config )
994
+ config .num_samples = calibration_dataset [2 ].get_length ()
995
+ quantized_decoder_w_p_model = _full_quantization (
996
+ self .model .decoder_with_past_model , config , calibration_dataset [2 ], ** kwargs
997
+ )
998
+ self .model .decoder_with_past_model = quantized_decoder_w_p_model
999
+ self .model .decoder_with_past .model = quantized_decoder_w_p_model
1000
+ self .model .decoder_with_past .request = None
998
1001
999
1002
1000
1003
def _weight_only_quantization (
0 commit comments