|
66 | 66 | from utils_tests import MODEL_NAMES, TEST_IMAGE_URL, mock_torch_cuda_is_available, patch_awq_for_inference
|
67 | 67 |
|
68 | 68 | from optimum.exporters.openvino.model_patcher import patch_update_causal_mask
|
| 69 | +from optimum.exporters.openvino.stateful import model_has_state |
69 | 70 | from optimum.intel import (
|
70 | 71 | OVDiffusionPipeline,
|
71 | 72 | OVFluxPipeline,
|
@@ -1625,12 +1626,18 @@ class OVModelForSeq2SeqLMIntegrationTest(unittest.TestCase):
|
1625 | 1626 | GENERATION_LENGTH = 100
|
1626 | 1627 | SPEEDUP_CACHE = 1.1
|
1627 | 1628 |
|
| 1629 | + SUPPORT_STATEFUL = ("t5", "mt5") |
| 1630 | + |
1628 | 1631 | @parameterized.expand(SUPPORTED_ARCHITECTURES)
|
1629 | 1632 | def test_compare_to_transformers(self, model_arch):
|
1630 | 1633 | model_id = MODEL_NAMES[model_arch]
|
1631 | 1634 | set_seed(SEED)
|
1632 | 1635 | ov_model = OVModelForSeq2SeqLM.from_pretrained(model_id, export=True, ov_config=F32_CONFIG)
|
1633 |
| - |
| 1636 | + expected_stateful = is_transformers_version(">", "4.43") and model_arch in self.SUPPORT_STATEFUL |
| 1637 | + self.assertEqual(ov_model.decoder.stateful, expected_stateful) |
| 1638 | + self.assertEqual(model_has_state(ov_model.decoder_model), expected_stateful) |
| 1639 | + check_with_past_available = self.assertIsNone if expected_stateful else self.assertIsNotNone |
| 1640 | + check_with_past_available(ov_model.decoder_with_past) |
1634 | 1641 | self.assertIsInstance(ov_model.encoder, OVEncoder)
|
1635 | 1642 | self.assertIsInstance(ov_model.decoder, OVDecoder)
|
1636 | 1643 | if not ov_model.decoder.stateful:
|
@@ -2339,6 +2346,12 @@ def test_compare_to_transformers(self, model_arch):
|
2339 | 2346 | transformers_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id)
|
2340 | 2347 | ov_model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True, ov_config=F32_CONFIG)
|
2341 | 2348 | self.assertIsInstance(ov_model.config, PretrainedConfig)
|
| 2349 | + # whisper cache class support implemented in 4.43 |
| 2350 | + expected_stateful = is_transformers_version(">", "4.43") |
| 2351 | + self.assertEqual(ov_model.decoder.stateful, expected_stateful) |
| 2352 | + self.assertEqual(model_has_state(ov_model.decoder_model), expected_stateful) |
| 2353 | + check_with_past_available = self.assertIsNone if expected_stateful else self.assertIsNotNone |
| 2354 | + check_with_past_available(ov_model.decoder_with_past) |
2342 | 2355 |
|
2343 | 2356 | processor = get_preprocessor(model_id)
|
2344 | 2357 | data = self._generate_random_audio_data()
|
|
0 commit comments