|
18 | 18 | import unittest
|
19 | 19 | from collections import defaultdict
|
20 | 20 | from functools import partial
|
21 |
| -from itertools import islice |
22 | 21 |
|
23 | 22 | import evaluate
|
24 | 23 | import numpy as np
|
@@ -599,27 +598,27 @@ def compute_metrics(p):
|
599 | 598 |
|
600 | 599 | class InferRequestWrapperTest(unittest.TestCase):
|
601 | 600 | MODEL_ID = ("openai/whisper-tiny.en",)
|
602 |
| - DATASET_ID = ("hf-internal-testing/librispeech_asr_dummy",) |
603 | 601 |
|
604 | 602 | @staticmethod
|
605 |
| - def _extract_input_features(processor, sample): |
| 603 | + def _generate_random_audio_data(processor): |
| 604 | + t = np.linspace(0, 1.0, int(1000), endpoint=False) |
| 605 | + audio_data = 0.5 * np.sin((2 + np.random.random()) * np.pi * t) |
606 | 606 | input_features = processor(
|
607 |
| - sample["audio"]["array"], |
608 |
| - sampling_rate=sample["audio"]["sampling_rate"], |
| 607 | + audio_data, |
| 608 | + sampling_rate=16000, |
609 | 609 | return_tensors="pt",
|
610 | 610 | ).input_features
|
611 | 611 | return input_features
|
612 | 612 |
|
613 |
| - @parameterized.expand(zip(MODEL_ID, DATASET_ID)) |
614 |
| - def test_calibration_data_uniqueness(self, model_id, dataset_id): |
| 613 | + @parameterized.expand(MODEL_ID) |
| 614 | + def test_calibration_data_uniqueness(self, model_id): |
615 | 615 | ov_model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True, compile=True)
|
616 | 616 | processor = AutoProcessor.from_pretrained(model_id)
|
617 | 617 |
|
618 |
| - dataset = load_dataset(dataset_id, "clean", split="validation") |
619 | 618 | calibration_data = []
|
620 | 619 | ov_model.decoder_with_past.request = InferRequestWrapper(ov_model.decoder_with_past.request, calibration_data)
|
621 |
| - for data in islice(dataset, 2): |
622 |
| - input_features = self._extract_input_features(processor, data) |
| 620 | + for _ in range(2): |
| 621 | + input_features = self._generate_random_audio_data(processor) |
623 | 622 | ov_model.generate(input_features)
|
624 | 623 |
|
625 | 624 | data_hashes_per_key = defaultdict(list)
|
|
0 commit comments