
Commit 6a73453

Audio models support with optimum.exporters.onnx (#622)

* Audio models, first draft
* Hubert, Wav2Vec and Sew work
* Add support for the other models
* Almost done
* Add architecture names to doc
* Fixed issue with SEW-D and WavLM
* Add tiny tests
* Fix
* Add models to tiny
* Add Data2Vec audio
* Fix dummy input generator tests
* Update doc
* Fix tests

Co-authored-by: Michael Benayoun <michael@huggingface.co>
1 parent e199283 commit 6a73453
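In short, this commit wires audio architectures (Wav2Vec2, Hubert, SEW, UniSpeech, WavLM, Speech2Text, AST, Data2Vec audio) into the ONNX export stack. A minimal sketch of driving the new path from Python follows; the checkpoint name and output path are illustrative, and the export() call shape is an assumption based on the exporters API of this period, not something shown in the commit itself:

    from pathlib import Path

    from transformers import AutoModelForCTC

    from optimum.exporters.onnx import export
    from optimum.exporters.onnx.model_configs import Wav2Vec2OnnxConfig

    # Checkpoint and output path are illustrative placeholders.
    model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")
    onnx_config = Wav2Vec2OnnxConfig(model.config, task="audio-ctc")

    # export() is assumed to follow the (model, config, output, opset) shape of
    # the exporters API at this time; it returns the matched input/output names.
    onnx_inputs, onnx_outputs = export(
        model, onnx_config, Path("wav2vec2_ctc.onnx"), opset=onnx_config.DEFAULT_ONNX_OPSET
    )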

8 files changed: +255 -26 lines

docs/source/exporters/onnx/package_reference/configuration.mdx (+10)

@@ -64,6 +64,7 @@ They specify which input generators should be used for the dummy inputs, but rem
 
 ## Supported architectures
 
+- Audio Spectrogram Transformer
 - Albert
 - Bart
 - Beit
@@ -78,6 +79,7 @@ They specify which input generators should be used for the dummy inputs, but rem
 - CodeGen
 - ConvBert
 - ConvNext
+- Data2VecAudio
 - Data2VecText
 - Data2VecVision
 - Deberta
@@ -91,6 +93,7 @@ They specify which input generators should be used for the dummy inputs, but rem
 - GPT-J
 - GPT-Neo
 - GroupVit
+- Hubert
 - IBert
 - LayoutLM
 - LayoutLM-v3
@@ -110,10 +113,17 @@ They specify which input generators should be used for the dummy inputs, but rem
 - Roberta
 - Roformer
 - Segformer
+- SEW
+- Speech2Text
 - SqueezeBert
 - Stable Diffusion
 - T5
+- UniSpeech
+- UniSpeech SAT
 - Vit
+- Wav2Vec2
+- Wav2Vec2 Conformer
+- WavLM
 - Whisper
 - XLM
 - XLM-Roberta

optimum/exporters/onnx/base.py (+5 -1)

@@ -158,7 +158,11 @@ class OnnxConfig(ExportConfig, ABC):
         ),
         "sequence-classification": OrderedDict({"logits": {0: "batch_size"}}),
         "token-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
-        "speech2seq-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
+        "speech2seq-lm": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
+        "audio-classification": OrderedDict({"logits": {0: "batch_size"}}),
+        "audio-frame-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
+        "audio-ctc": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
+        "audio-xvector": OrderedDict({"logits": {0: "batch_size"}, "embeddings": {0: "batch_size"}}),
     }
 
     def __init__(
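Each OrderedDict added above declares, per task, the ONNX output names and their dynamic axes. For orientation, here is a hand-rolled sketch of what the new "audio-ctc" entry amounts to when handed to torch.onnx.export; the checkpoint and file names are illustrative, and optimum's real exporter additionally generates dummy inputs and validates outputs:

    import torch
    from transformers import AutoModelForCTC

    model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")  # illustrative
    model.config.return_dict = False  # trace-friendly tuple outputs
    dummy = torch.randn(2, 16000)     # (batch_size, sequence_length) raw waveform

    torch.onnx.export(
        model,
        (dummy,),
        "manual_export.onnx",
        input_names=["input_values"],
        output_names=["logits"],
        # These dicts mirror the "audio-ctc" axes declared in base.py above.
        dynamic_axes={
            "input_values": {0: "batch_size", 1: "sequence_length"},
            "logits": {0: "batch_size", 1: "sequence_length"},
        },
    )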

optimum/exporters/onnx/config.py (+4)

@@ -162,6 +162,10 @@ class AudioOnnxConfig(OnnxConfig):
 
     DUMMY_INPUT_GENERATOR_CLASSES = (DummyAudioInputGenerator,)
 
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return {"input_values": {0: "batch_size", 1: "sequence_length"}}
+
 
 class AudioToTextOnnxConfig(OnnxSeq2SeqConfigWithPast):
     DUMMY_INPUT_GENERATOR_CLASSES = (
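With the inputs property above, every AudioOnnxConfig subclass advertises a single input_values tensor with dynamic batch and sequence axes, and the inherited generate_dummy_inputs() fills it through DummyAudioInputGenerator. A quick inspection sketch (checkpoint name illustrative):

    from transformers import AutoConfig

    from optimum.exporters.onnx.model_configs import HubertOnnxConfig

    config = AutoConfig.from_pretrained("facebook/hubert-base-ls960")  # illustrative
    onnx_config = HubertOnnxConfig(config, task="audio-ctc")

    print(onnx_config.inputs)
    # {"input_values": {0: "batch_size", 1: "sequence_length"}}

    # Dummy inputs are produced by DummyAudioInputGenerator with default shapes.
    dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
    print({name: tuple(t.shape) for name, t in dummy_inputs.items()})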

optimum/exporters/onnx/model_configs.py (+101 -15)

@@ -20,6 +20,7 @@
 
 from ...utils import (
     DEFAULT_DUMMY_SHAPES,
+    DummyAudioInputGenerator,
     DummyDecoderTextInputGenerator,
     DummyPastKeyValuesGenerator,
     DummySeq2SeqDecoderTextInputGenerator,
@@ -34,8 +35,9 @@
     NormalizedVisionConfig,
     logging,
 )
-from .base import ConfigBehavior, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
+from .base import ConfigBehavior, OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
 from .config import (
+    AudioOnnxConfig,
     AudioToTextOnnxConfig,
     TextAndVisionOnnxConfig,
     TextDecoderOnnxConfig,
@@ -514,6 +516,18 @@ class SegformerOnnxConfig(YolosOnnxConfig):
     pass
 
 
+class MobileNetV1OnnxConfig(ViTOnnxConfig):
+    ATOL_FOR_VALIDATION = 1e-4
+
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return {"pixel_values": {0: "batch_size"}}
+
+
+class MobileNetV2OnnxConfig(MobileNetV1OnnxConfig):
+    pass
+
+
 class CLIPNormalizedConfig(NormalizedTextAndVisionConfig):
     TEXT_CONFIG = "text_config"
     VISION_CONFIG = "vision_config"
@@ -693,11 +707,9 @@ class Data2VecVisionOnnxConfig(ViTOnnxConfig):
     pass
 
 
-# TODO: add support when audio models are supported.
-class Data2VecAudioOnnxConfig(ViTOnnxConfig):
-    @property
-    def inputs(self):
-        raise NotImplementedError
+class Data2VecAudioOnnxConfig(AudioOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedConfig
+    ATOL_FOR_VALIDATION = 1e-4
 
 
 class PerceiverDummyInputGenerator(DummyVisionInputGenerator):
@@ -751,20 +763,94 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
         return dummy_inputs
 
 
-class WhisperOnnxConfig(AudioToTextOnnxConfig):
-    NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig
-    ATOL_FOR_VALIDATION = 1e-3
+class HubertOnnxConfig(AudioOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedConfig
 
 
-class MobileNetV1OnnxConfig(VisionOnnxConfig):
-    NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
-    MIN_TORCH_VERSION = version.parse("1.11")
+class Wav2Vec2OnnxConfig(HubertOnnxConfig):
+    pass
+
+
+class Wav2Vec2ConformerOnnxConfig(HubertOnnxConfig):
+    pass
+
+
+class SEWOnnxConfig(HubertOnnxConfig):
+    pass
+
+
+class SEWDOnnxConfig(HubertOnnxConfig):
+    DEFAULT_ONNX_OPSET = 12
+
+
+class UniSpeechOnnxConfig(HubertOnnxConfig):
+    pass
+
+
+class UniSpeechSATOnnxConfig(HubertOnnxConfig):
+    pass
+
+
+class WavLMOnnxConfig(HubertOnnxConfig):
+    DEFAULT_ONNX_OPSET = 12
+
+
+class ASTDummyAudioInputGenerator(DummyAudioInputGenerator):
+    def generate(self, input_name: str, framework: str = "pt"):
+        shape = [self.batch_size, self.normalized_config.max_length, self.normalized_config.num_mel_bins]
+        if input_name == "input_values":
+            return self.random_float_tensor(shape, min_value=-1, max_value=1, framework=framework)
+        return super().generate(input_name, framework=framework)
+
+
+class ASTOnnxConfig(OnnxConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
+        num_mel_bins="num_mel_bins", max_length="max_length", allow_new=True
+    )
+    DUMMY_INPUT_GENERATOR_CLASSES = (ASTDummyAudioInputGenerator,)
     ATOL_FOR_VALIDATION = 1e-4
 
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
-        return {"pixel_values": {0: "batch"}}
+        return {"input_values": {0: "batch_size"}}
 
 
-class MobileNetV2OnnxConfig(MobileNetV1OnnxConfig):
-    pass
+# TODO: currently disabled because an operator seems not supported by ONNX.
+# class MCTCTDummyAudioInputGenerator(DummyAudioInputGenerator):
+#     def generate(self, input_name: str, framework: str = "pt"):
+#         shape = [self.batch_size, self.sequence_length, self.normalized_config.input_features_per_channel]
+#         if input_name == "input_features":
+#             return self.random_float_tensor(shape, min_value=-1, max_value=1, framework=framework)
+#         return super().generate(input_name, framework=framework)
#
#
+# class MCTCTOnnxConfig(OnnxConfig):
+#     NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(input_features_per_channel="input_feat_per_channel", allow_new=True)
+#     DUMMY_INPUT_GENERATOR_CLASSES = (MCTCTDummyAudioInputGenerator,)
+#     DEFAULT_ONNX_OPSET = 13
#
+#     @property
+#     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+#         return {"input_features": {0: "batch_size", 1: "sequence_classification"}}
+
+
+class WhisperOnnxConfig(AudioToTextOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig
+    ATOL_FOR_VALIDATION = 1e-3
+
+
+class Speech2TextDummyAudioInputGenerator(DummyAudioInputGenerator):
+    def generate(self, input_name: str, framework: str = "pt"):
+        shape = [self.batch_size, self.sequence_length, self.normalized_config.input_features_per_channel]
+        if input_name == "input_features":
+            return self.random_float_tensor(shape, min_value=-1, max_value=1, framework=framework)
+        return super().generate(input_name, framework=framework)
+
+
+class Speech2TextOnnxConfig(AudioToTextOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args(
+        input_features_per_channel="input_feat_per_channel", allow_new=True
+    )
+    DUMMY_INPUT_GENERATOR_CLASSES = (
+        Speech2TextDummyAudioInputGenerator,
+    ) + AudioToTextOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES[1:]
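Unlike the raw-waveform models above, AST consumes a precomputed log-mel spectrogram, which is why ASTDummyAudioInputGenerator builds a (batch_size, max_length, num_mel_bins) tensor and ASTOnnxConfig only marks the batch axis as dynamic. A sketch that makes the shape visible (checkpoint name illustrative):

    from transformers import AutoConfig

    from optimum.exporters.onnx.model_configs import ASTOnnxConfig

    config = AutoConfig.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")  # illustrative
    onnx_config = ASTOnnxConfig(config, task="audio-classification")

    dummy = onnx_config.generate_dummy_inputs(framework="pt")
    # Expected shape: (batch_size, max_length, num_mel_bins), e.g. (2, 1024, 128).
    print(dummy["input_values"].shape)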

optimum/exporters/tasks.py (+96 -6)

@@ -100,6 +100,10 @@ class TasksManager:
         "masked-im": "AutoModelForMaskedImageModeling",
         "semantic-segmentation": "AutoModelForSemanticSegmentation",
         "speech2seq-lm": "AutoModelForSpeechSeq2Seq",
+        "audio-classification": "AutoModelForAudioClassification",
+        "audio-frame-classification": "AutoModelForAudioFrameClassification",
+        "audio-ctc": "AutoModelForCTC",
+        "audio-xvector": "AutoModelForAudioXVector",
         "stable-diffusion": "StableDiffusionPipeline",
     }
     if is_tf_available():
@@ -130,11 +134,20 @@ class TasksManager:
         "masked-im": "transformers",
         "semantic-segmentation": "transformers",
         "speech2seq-lm": "transformers",
+        "audio-ctc": "transformers",
+        "audio-classification": "transformers",
+        "audio-frame-classification": "transformers",
+        "audio-xvector": "transformers",
         "stable-diffusion": "diffusers",
     }
 
     # Set of model topologies we support associated to the tasks supported by each topology and the factory
     _SUPPORTED_MODEL_TYPE = {
+        "audio-spectrogram-transformer": supported_tasks_mapping(
+            "default",
+            "audio-classification",
+            onnx="ASTOnnxConfig",
+        ),
         "albert": supported_tasks_mapping(
             "default",
             "masked-lm",
@@ -273,6 +286,14 @@ class TasksManager:
             # "semantic-segmentation",
             onnx="Data2VecVisionOnnxConfig",
         ),
+        "data2vec-audio": supported_tasks_mapping(
+            "default",
+            "audio-ctc",
+            "audio-classification",
+            "audio-frame-classification",
+            "audio-xvector",
+            onnx="Data2VecAudioOnnxConfig",
+        ),
         "deberta": supported_tasks_mapping(
             "default",
             "masked-lm",
@@ -356,6 +377,12 @@ class TasksManager:
             "default",
             onnx="GroupViTOnnxConfig",
         ),
+        "hubert": supported_tasks_mapping(
+            "default",
+            "audio-ctc",
+            "audio-classification",
+            onnx="HubertOnnxConfig",
+        ),
         "ibert": supported_tasks_mapping(
             "default",
             "masked-lm",
@@ -423,6 +450,12 @@ class TasksManager:
             "question-answering",
             onnx="MBartOnnxConfig",
         ),
+        # TODO: enable once the missing operator is supported.
+        # "mctct": supported_tasks_mapping(
+        #     "default",
+        #     "audio-ctc",
+        #     onnx="MCTCTOnnxConfig",
+        # ),
         "mobilebert": supported_tasks_mapping(
             "default",
             "masked-lm",
@@ -521,6 +554,25 @@ class TasksManager:
             "semantic-segmentation",
             onnx="SegformerOnnxConfig",
         ),
+        "sew": supported_tasks_mapping(
+            "default",
+            "audio-ctc",
+            "audio-classification",
+            onnx="SEWOnnxConfig",
+        ),
+        "sew-d": supported_tasks_mapping(
+            "default",
+            "audio-ctc",
+            "audio-classification",
+            onnx="SEWDOnnxConfig",
+        ),
+        "speech-to-text": supported_tasks_mapping(
+            "default",
+            "default-with-past",
+            "speech2seq-lm",
+            "speech2seq-lm-with-past",
+            onnx="Speech2TextOnnxConfig",
+        ),
         "squeezebert": supported_tasks_mapping(
             "default",
             "masked-lm",
@@ -530,6 +582,12 @@ class TasksManager:
             "question-answering",
             onnx="SqueezeBertOnnxConfig",
         ),
+        "swin": supported_tasks_mapping(
+            "default",
+            "image-classification",
+            "masked-im",
+            onnx="SwinOnnxConfig",
+        ),
         "t5": supported_tasks_mapping(
             "default",
             "default-with-past",
@@ -541,11 +599,49 @@ class TasksManager:
             "semantic-segmentation",
             onnx="UNetOnnxConfig",
         ),
+        "unispeech": supported_tasks_mapping(
+            "default",
+            "audio-ctc",
+            "audio-classification",
+            onnx="UniSpeechOnnxConfig",
+        ),
+        "unispeech-sat": supported_tasks_mapping(
+            "default",
+            "audio-ctc",
+            "audio-classification",
+            "audio-frame-classification",
+            "audio-xvector",
+            onnx="UniSpeechSATOnnxConfig",
+        ),
         "vae": supported_tasks_mapping(
             "semantic-segmentation",
             onnx="VaeOnnxConfig",
         ),
         "vit": supported_tasks_mapping("default", "image-classification", "masked-im", onnx="ViTOnnxConfig"),
+        "wavlm": supported_tasks_mapping(
+            "default",
+            "audio-ctc",
+            "audio-classification",
+            "audio-frame-classification",
+            "audio-xvector",
+            onnx="WavLMOnnxConfig",
+        ),
+        "wav2vec2": supported_tasks_mapping(
+            "default",
+            "audio-ctc",
+            "audio-classification",
+            "audio-frame-classification",
+            "audio-xvector",
+            onnx="Wav2Vec2OnnxConfig",
+        ),
+        "wav2vec2-conformer": supported_tasks_mapping(
+            "default",
+            "audio-ctc",
+            "audio-classification",
+            "audio-frame-classification",
+            "audio-xvector",
+            onnx="Wav2Vec2ConformerOnnxConfig",
+        ),
         "whisper": supported_tasks_mapping(
             "default",
             "default-with-past",
@@ -580,12 +676,6 @@ class TasksManager:
             "object-detection",
             onnx="YolosOnnxConfig",
         ),
-        "swin": supported_tasks_mapping(
-            "default",
-            "image-classification",
-            "masked-im",
-            onnx="SwinOnnxConfig",
-        ),
     }
     _UNSUPPORTED_CLI_MODEL_TYPE = {"unet", "vae", "clip-text-model"}
     _SUPPORTED_CLI_MODEL_TYPE = set(_SUPPORTED_MODEL_TYPE.keys()) - _UNSUPPORTED_CLI_MODEL_TYPE
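Once a model type is registered in _SUPPORTED_MODEL_TYPE, the exported graph runs like any other ONNX model. A sanity-check sketch with ONNX Runtime, assuming wav2vec2_ctc.onnx was produced as in the export sketch near the top of this page:

    import numpy as np
    import onnxruntime

    # Load the exported graph on CPU and feed one second of random 16 kHz audio.
    session = onnxruntime.InferenceSession("wav2vec2_ctc.onnx", providers=["CPUExecutionProvider"])
    waveform = np.random.randn(1, 16000).astype(np.float32)

    (logits,) = session.run(None, {"input_values": waveform})
    # Expected: (batch_size, downsampled sequence_length, vocab_size).
    print(logits.shape)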
