Skip to content

Commit 8879d0c

Browse files
committed
fix more tests
1 parent e3ffb2a commit 8879d0c

File tree

3 files changed

+8
-4
lines changed

3 files changed

+8
-4
lines changed

optimum/intel/openvino/quantization.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -878,7 +878,10 @@ def _prepare_speech_to_text_calibration_data(self, config: OVQuantizationConfigB
878878
if decoder_w_p_model is not None:
879879
decoder_w_p_model.request = decoder_w_p_model.request.request
880880

881-
datasets = [nncf.Dataset(encoder_calibration_data), nncf.Dataset(decoder_calibration_data),]
881+
datasets = [
882+
nncf.Dataset(encoder_calibration_data),
883+
nncf.Dataset(decoder_calibration_data),
884+
]
882885
if decoder_w_p_model is not None:
883886
datasets.append(nncf.Dataset(decoder_w_p_calibration_data))
884887
return datasets

tests/openvino/test_modeling.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -607,8 +607,9 @@ def test_seq2seq_load_from_hub(self):
607607
with TemporaryDirectory() as tmpdirname:
608608
ov_exported_pipe.save_pretrained(tmpdirname)
609609
folder_contents = os.listdir(tmpdirname)
610-
self.assertTrue(OV_DECODER_WITH_PAST_NAME in folder_contents)
611-
self.assertTrue(OV_DECODER_WITH_PAST_NAME.replace(".xml", ".bin") in folder_contents)
610+
if not ov_exported_pipe.model.decoder.stateful:
611+
self.assertTrue(OV_DECODER_WITH_PAST_NAME in folder_contents)
612+
self.assertTrue(OV_DECODER_WITH_PAST_NAME.replace(".xml", ".bin") in folder_contents)
612613
ov_exported_pipe = optimum_pipeline("text2text-generation", tmpdirname, accelerator="openvino")
613614
self.assertIsInstance(ov_exported_pipe.model, OVBaseModel)
614615

tests/openvino/test_quantization.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1281,7 +1281,7 @@ def test_calibration_data_uniqueness(self, model_name, apply_caching):
12811281

12821282
for inputs_dict in calibration_data:
12831283
for k, v in inputs_dict.items():
1284-
if k == "input_ids":
1284+
if k in ["input_ids", "beam_idx"]:
12851285
continue
12861286

12871287
x = (v.numpy() if isinstance(v, torch.Tensor) else v).copy()

0 commit comments

Comments
 (0)