Skip to content

Commit 0249b17

Browse files
committed
fix more tests
1 parent de9a776 commit 0249b17

File tree

3 files changed

+8
-4
lines changed

3 files changed

+8
-4
lines changed

optimum/intel/openvino/quantization.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -872,7 +872,10 @@ def _prepare_speech_to_text_calibration_data(self, config: OVQuantizationConfigB
872872
if decoder_w_p_model is not None:
873873
decoder_w_p_model.request = decoder_w_p_model.request.request
874874

875-
datasets = [nncf.Dataset(encoder_calibration_data), nncf.Dataset(decoder_calibration_data),]
875+
datasets = [
876+
nncf.Dataset(encoder_calibration_data),
877+
nncf.Dataset(decoder_calibration_data),
878+
]
876879
if decoder_w_p_model is not None:
877880
datasets.append(nncf.Dataset(decoder_w_p_calibration_data))
878881
return datasets

tests/openvino/test_modeling.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -535,8 +535,9 @@ def test_seq2seq_load_from_hub(self):
535535
with TemporaryDirectory() as tmpdirname:
536536
ov_exported_pipe.save_pretrained(tmpdirname)
537537
folder_contents = os.listdir(tmpdirname)
538-
self.assertTrue(OV_DECODER_WITH_PAST_NAME in folder_contents)
539-
self.assertTrue(OV_DECODER_WITH_PAST_NAME.replace(".xml", ".bin") in folder_contents)
538+
if not ov_exported_pipe.model.decoder.stateful:
539+
self.assertTrue(OV_DECODER_WITH_PAST_NAME in folder_contents)
540+
self.assertTrue(OV_DECODER_WITH_PAST_NAME.replace(".xml", ".bin") in folder_contents)
540541
ov_exported_pipe = optimum_pipeline("text2text-generation", tmpdirname, accelerator="openvino")
541542
self.assertIsInstance(ov_exported_pipe.model, OVBaseModel)
542543

tests/openvino/test_quantization.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1230,7 +1230,7 @@ def test_calibration_data_uniqueness(self, model_name, apply_caching):
12301230

12311231
for inputs_dict in calibration_data:
12321232
for k, v in inputs_dict.items():
1233-
if k == "input_ids":
1233+
if k in ["input_ids", "beam_idx"]:
12341234
continue
12351235

12361236
x = (v.numpy() if isinstance(v, torch.Tensor) else v).copy()

0 commit comments

Comments
 (0)