Skip to content

Commit c92856e

Browse files
committed
added test class for ast
1 parent cb952a9 commit c92856e

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

tests/bettertransformer/test_bettertransformer_audio.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@
2626
"openai/whisper-tiny",
2727
"patrickvonplaten/wav2vec2_tiny_random",
2828
"ybelkada/hubert-tiny-random",
29-
"MIT/ast-finetuned-audioset-10-10-0.4593",
29+
]
30+
31+
AST_TO_TEST = [
32+
"Ericwang/tiny-random-ast",
3033
]
3134

3235

@@ -56,6 +59,22 @@ def prepare_inputs_for_class(self, model_id):
5659
}
5760
return input_dict
5861

62+
class BetterTransformersASTTest(BetterTransformersTestMixin, unittest.TestCase):
63+
r"""
64+
Testing suite for AST - tests all the tests defined in `BetterTransformersTestMixin`
65+
Since `AST` uses slightly different preprocessor than other audio models, it is preferrable
66+
to define its own testing class.
67+
"""
68+
all_models_to_test = AST_TO_TEST
69+
70+
def prepare_inputs_for_class(self, model_id):
71+
batch_duration_in_seconds = [1, 3, 2, 6]
72+
input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds]
73+
74+
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
75+
76+
input_dict = feature_extractor(input_features, return_tensors="pt", padding=True)
77+
return input_dict
5978

6079
class BetterTransformersAudioTest(BetterTransformersTestMixin, unittest.TestCase):
6180
r"""

0 commit comments

Comments
 (0)