Skip to content

Commit 27b30ac

Browse files
committed
fix quantization
1 parent bb18f27 commit 27b30ac

File tree

2 files changed: +35 −27 lines changed

optimum/intel/openvino/modeling_seq2seq.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,7 @@ def __init__(self, model: openvino.runtime.Model, parent_model: OVModelForSeq2Se
577577
is_legacy = any("past_key_values" in key.get_any_name() for key in self.model.outputs)
578578
self.use_past = len(self.key_value_input_names) > 0 or self.stateful
579579
self.next_beam_idx = None
580+
self._past_length = 0
580581

581582
if len(self.key_value_input_names) > 0 and not is_legacy:
582583
self.use_past = True
@@ -625,7 +626,7 @@ def forward(
625626

626627
if self.stateful and past_key_values is None:
627628
self.request.reset_state()
628-
self._past_len = 0
629+
self._past_length = 0
629630

630631
if past_key_values is not None and not self.stateful:
631632
# Flatten the past_key_values
@@ -664,7 +665,7 @@ def forward(
664665
self.request.start_async(inputs, share_inputs=True)
665666
self.request.wait()
666667
logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device)
667-
self._past_len += input_ids.shape[1]
668+
self._past_length += input_ids.shape[1]
668669

669670
out_past_key_values = ()
670671

@@ -689,6 +690,13 @@ def forward(
689690

690691
return Seq2SeqLMOutput(logits=logits, past_key_values=out_past_key_values)
691692

693+
def _get_past_length(self, past_key_values=None):
694+
if past_key_values is None:
695+
return 0
696+
if self.stateful:
697+
return self._past_length
698+
return past_key_values[0][0].shape[-2]
699+
692700
def __call__(self, *args, **kwargs):
693701
return self.forward(*args, **kwargs)
694702

@@ -1074,10 +1082,7 @@ def prepare_inputs_for_generation(
10741082

10751083
past_length = 0
10761084
if past_key_values is not None:
1077-
if self.decoder.stateful:
1078-
past_length = getattr(self.decoder, "_past_len", 0)
1079-
else:
1080-
past_length = past_key_values[0][0].shape[2]
1085+
self.decoder._get_past_length(past_key_values)
10811086

10821087
# Some generation methods already pass only the last input ID
10831088
if decoder_input_ids.shape[1] > past_length:

optimum/intel/openvino/quantization.py

+24-21
Original file line numberDiff line numberDiff line change
@@ -828,12 +828,14 @@ def _prepare_speech_to_text_calibration_data(self, config: OVQuantizationConfigB
828828
decoder_model.request, decoder_calibration_data, apply_caching=True
829829
)
830830

831-
decoder_w_p_calibration_data = []
832-
decoder_w_p_model = self.model.decoder_with_past
833-
decoder_w_p_model._compile()
834-
decoder_w_p_model.request = InferRequestWrapper(
835-
decoder_w_p_model.request, decoder_w_p_calibration_data, apply_caching=True
836-
)
831+
decoder_w_p_model = None
832+
if self.model.decoder_with_past_model is not None:
833+
decoder_w_p_calibration_data = []
834+
decoder_w_p_model = self.model.decoder_with_past
835+
decoder_w_p_model._compile()
836+
decoder_w_p_model.request = InferRequestWrapper(
837+
decoder_w_p_model.request, decoder_w_p_calibration_data, apply_caching=True
838+
)
837839

838840
dataset_metadata = PREDEFINED_SPEECH_TO_TEXT_DATASETS[config.dataset]
839841

@@ -867,13 +869,13 @@ def _prepare_speech_to_text_calibration_data(self, config: OVQuantizationConfigB
867869
finally:
868870
encoder_model.request = encoder_model.request.request
869871
decoder_model.request = decoder_model.request.request
870-
decoder_w_p_model.request = decoder_w_p_model.request.request
872+
if decoder_w_p_model is not None:
873+
decoder_w_p_model.request = decoder_w_p_model.request.request
871874

872-
return (
873-
nncf.Dataset(encoder_calibration_data),
874-
nncf.Dataset(decoder_calibration_data),
875-
nncf.Dataset(decoder_w_p_calibration_data),
876-
)
875+
datasets = [nncf.Dataset(encoder_calibration_data), nncf.Dataset(decoder_calibration_data),]
876+
if decoder_w_p_model is not None:
877+
datasets.append(nncf.Dataset(decoder_w_p_calibration_data))
878+
return datasets
877879

878880
def _prepare_text_generation_calibration_data(
879881
self, quantization_config: OVQuantizationConfigBase, calibration_dataloader: OVDataLoader
@@ -986,15 +988,16 @@ def _quantize_whisper_model(self, quantization_config, calibration_dataset, **kw
986988
self.model.decoder.model = quantized_decoder_model
987989
self.model.decoder.request = None
988990

989-
# Quantize decoder with past model
990-
config = copy.deepcopy(quantization_config)
991-
config.num_samples = calibration_dataset[2].get_length()
992-
quantized_decoder_w_p_model = _full_quantization(
993-
self.model.decoder_with_past_model, config, calibration_dataset[2], **kwargs
994-
)
995-
self.model.decoder_with_past_model = quantized_decoder_w_p_model
996-
self.model.decoder_with_past.model = quantized_decoder_w_p_model
997-
self.model.decoder_with_past.request = None
991+
if self.model.decoder_with_past_model is not None:
992+
# Quantize decoder with past model
993+
config = copy.deepcopy(quantization_config)
994+
config.num_samples = calibration_dataset[2].get_length()
995+
quantized_decoder_w_p_model = _full_quantization(
996+
self.model.decoder_with_past_model, config, calibration_dataset[2], **kwargs
997+
)
998+
self.model.decoder_with_past_model = quantized_decoder_w_p_model
999+
self.model.decoder_with_past.model = quantized_decoder_w_p_model
1000+
self.model.decoder_with_past.request = None
9981001

9991002

10001003
def _weight_only_quantization(

0 commit comments

Comments (0)