Skip to content

Commit b12dca9

Browse files
committed
update test to check that stateful expected
1 parent 9219632 commit b12dca9

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

tests/openvino/test_modeling.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
from utils_tests import MODEL_NAMES, TEST_IMAGE_URL, mock_torch_cuda_is_available, patch_awq_for_inference
6767

6868
from optimum.exporters.openvino.model_patcher import patch_update_causal_mask
69+
from optimum.exporters.openvino.stateful import model_has_state
6970
from optimum.intel import (
7071
OVDiffusionPipeline,
7172
OVFluxPipeline,
@@ -1625,12 +1626,18 @@ class OVModelForSeq2SeqLMIntegrationTest(unittest.TestCase):
16251626
GENERATION_LENGTH = 100
16261627
SPEEDUP_CACHE = 1.1
16271628

1629+
SUPPORT_STATEFUL = ("t5", "mt5")
1630+
16281631
@parameterized.expand(SUPPORTED_ARCHITECTURES)
16291632
def test_compare_to_transformers(self, model_arch):
16301633
model_id = MODEL_NAMES[model_arch]
16311634
set_seed(SEED)
16321635
ov_model = OVModelForSeq2SeqLM.from_pretrained(model_id, export=True, ov_config=F32_CONFIG)
1633-
1636+
expected_stateful = is_transformers_version(">", "4.43") and model_arch in self.SUPPORT_STATEFUL
1637+
self.assertEqual(ov_model.decoder.stateful, expected_stateful)
1638+
self.assertEqual(model_has_state(ov_model.decoder_model), expected_stateful)
1639+
check_with_past_available = self.assertIsNone if expected_stateful else self.assertIsNotNone
1640+
check_with_past_available(ov_model.decoder_with_past)
16341641
self.assertIsInstance(ov_model.encoder, OVEncoder)
16351642
self.assertIsInstance(ov_model.decoder, OVDecoder)
16361643
if not ov_model.decoder.stateful:
@@ -2339,6 +2346,12 @@ def test_compare_to_transformers(self, model_arch):
23392346
transformers_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id)
23402347
ov_model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True, ov_config=F32_CONFIG)
23412348
self.assertIsInstance(ov_model.config, PretrainedConfig)
2349+
# whisper cache class support implemented in 4.43
2350+
expected_stateful = is_transformers_version(">", "4.43")
2351+
self.assertEqual(ov_model.decoder.stateful, expected_stateful)
2352+
self.assertEqual(model_has_state(ov_model.decoder_model), expected_stateful)
2353+
check_with_past_available = self.assertIsNone if expected_stateful else self.assertIsNotNone
2354+
check_with_past_available(ov_model.decoder_with_past)
23422355

23432356
processor = get_preprocessor(model_id)
23442357
data = self._generate_random_audio_data()

0 commit comments

Comments
 (0)