Skip to content

Commit de9a776

Browse files
committed
fix tests
1 parent 3061342 commit de9a776

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

tests/openvino/test_exporters_cli.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def test_exporters_cli_int8(self, task: str, model_type: str):
269269

270270
if task.startswith("text2text-generation"):
271271
models = [model.encoder, model.decoder]
272-
if task.endswith("with-past"):
272+
if task.endswith("with-past") and not model.decoder.stateful:
273273
models.append(model.decoder_with_past)
274274
elif model_type.startswith("stable-diffusion") or model_type.startswith("flux"):
275275
models = [model.unet or model.transformer, model.vae_encoder, model.vae_decoder]

tests/openvino/test_modeling.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1527,8 +1527,9 @@ def test_compare_to_transformers(self, model_arch):
15271527

15281528
self.assertIsInstance(ov_model.encoder, OVEncoder)
15291529
self.assertIsInstance(ov_model.decoder, OVDecoder)
1530-
self.assertIsInstance(ov_model.decoder_with_past, OVDecoder)
1531-
self.assertIsInstance(ov_model.config, PretrainedConfig)
1530+
if not ov_model.decoder.stateful:
1531+
self.assertIsInstance(ov_model.decoder_with_past, OVDecoder)
1532+
self.assertIsInstance(ov_model.config, PretrainedConfig)
15321533

15331534
transformers_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
15341535
tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -1613,7 +1614,7 @@ def test_generate_utils(self, model_arch):
16131614
gc.collect()
16141615

16151616
def test_compare_with_and_without_past_key_values(self):
1616-
model_id = MODEL_NAMES["t5"]
1617+
model_id = MODEL_NAMES["bart"]
16171618
tokenizer = AutoTokenizer.from_pretrained(model_id)
16181619
text = "This is a sample input"
16191620
tokens = tokenizer(text, return_tensors="pt")

tests/openvino/test_quantization.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ class OVQuantizerTest(unittest.TestCase):
106106
weight_only=False,
107107
smooth_quant_alpha=0.95,
108108
),
109-
(14, 22, 21) if is_transformers_version("<=", "4.42.4") else (14, 22, 25),
109+
(14, 22, 21) if is_transformers_version("<=", "4.42.4") else (14, 26, 25),
110110
(14, 21, 17) if is_transformers_version("<=", "4.42.4") else (14, 22, 18),
111111
),
112112
]
@@ -1213,7 +1213,7 @@ def test_calibration_data_uniqueness(self, model_name, apply_caching):
12131213
processor = AutoProcessor.from_pretrained(model_id)
12141214

12151215
calibration_data = []
1216-
if not ov_model.stateful:
1216+
if not ov_model.decoder.stateful:
12171217
ov_model.decoder_with_past.request = InferRequestWrapper(
12181218
ov_model.decoder_with_past.request, calibration_data, apply_caching=apply_caching
12191219
)

0 commit comments

Comments
 (0)