Commit 414afab
Add support for Moonshine ONNX export (& seq2seq models with non-legacy cache & Tensor.repeat_interleave) (#2162)
* Add moonshine ONNX config
* Remove use_cache_position for whisper exports
* Patch torch repeat_interleave during export
* Add support for exporting models with non-legacy caches
* Formatting
* Re-use model patcher for seq2seq models
* Add moonshine unit tests
* Formatting
* When tracing, repeats passed as an int will be turned into a tensor of rank 0.
* Fix failing unit test on 4.45.1 CI. Confirmed it works above 4.46 too.
1 parent 27dae50 commit 414afab
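
For context, once this lands, a Moonshine checkpoint can be exported through Optimum's regular ONNX export entry point. A minimal sketch, using the checkpoint referenced in the test changes below (treat the exact arguments as assumptions):

from optimum.exporters.onnx import main_export

# Sketch: export Moonshine (encoder + decoder, with past key values) to ONNX.
main_export(
    "UsefulSensors/moonshine-tiny",
    output="moonshine_onnx",
    task="automatic-speech-recognition-with-past",
)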

File tree

4 files changed: +103 −20 lines

optimum/exporters/onnx/model_configs.py (+30 −3)
optimum/exporters/onnx/model_patcher.py (+64 −17)
optimum/exporters/tasks.py (+7 −0)
tests/exporters/exporters_utils.py (+2 −0)

optimum/exporters/onnx/model_configs.py

+30 −3
@@ -1782,6 +1782,33 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
         return {"input_features": {0: "batch_size", 1: "sequence_classification"}}
 
 
+class MoonshineOnnxConfig(AudioToTextOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig
+
+    # torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::triu' to ONNX opset version 11 is not supported.
+    # Support for this operator was added in version 14, try exporting with this version.
+    DEFAULT_ONNX_OPSET = 14
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        common_inputs = {}
+
+        if self._behavior is not ConfigBehavior.DECODER:
+            common_inputs["input_values"] = {0: "batch_size", 1: "num_samples"}
+
+        if self._behavior is not ConfigBehavior.ENCODER:
+            if self.use_past_in_inputs:
+                common_inputs["decoder_input_ids"] = {0: "batch_size"}
+                self.add_past_key_values(common_inputs, direction="inputs")
+            else:
+                common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}
+
+        if self._behavior is ConfigBehavior.DECODER:
+            common_inputs["encoder_outputs"] = {0: "batch_size", 1: "encoder_sequence_length"}
+
+        return common_inputs
+
+
 class WhisperOnnxConfig(AudioToTextOnnxConfig):
     DEFAULT_ONNX_OPSET = 14  # Whisper now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
 
@@ -1802,9 +1829,9 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
         if self._behavior is not ConfigBehavior.DECODER:
             common_inputs["input_features"] = {0: "batch_size"}  # Remove unnecessary dynamic axis.
 
-        if self._behavior is not ConfigBehavior.ENCODER and self.use_past_in_inputs:
-            if is_transformers_version(">=", "4.43.0"):
-                # since https://github.com/huggingface/transformers/pull/31166
+        if is_transformers_version(">=", "4.43.0") and is_transformers_version("<", "4.46.0"):
+            # since https://github.com/huggingface/transformers/pull/31166
+            if self._behavior is not ConfigBehavior.ENCODER and self.use_past_in_inputs:
                 common_inputs["cache_position"] = {0: "decoder_sequence_length"}
 
         if self._behavior is ConfigBehavior.DECODER and not self.use_past_in_inputs:
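
For context on the MoonshineOnnxConfig.inputs property added above, here is an illustration (not part of the commit) of what it resolves to for a decoder-with-past export. The per-layer past-key-values names and axes are assumptions about what add_past_key_values produces:

# Hypothetical result for behavior=DECODER, use_past_in_inputs=True:
expected_inputs = {
    "decoder_input_ids": {0: "batch_size"},
    "encoder_outputs": {0: "batch_size", 1: "encoder_sequence_length"},
    # plus one entry per decoder layer from add_past_key_values, e.g. (assumed naming):
    # "past_key_values.0.decoder.key": {0: "batch_size", 2: "past_decoder_sequence_length"},
}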

optimum/exporters/onnx/model_patcher.py

+64 −17
@@ -156,7 +156,51 @@ def onnx_compatible_unfold(input_tensor, dimension, size, step):
     return result
 
 
-UNSUPPORTED_OPS_PATCHING_SPEC = [PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold)]
+# An ONNX-export-compatible version of `tensor.repeat_interleave`.
+# Without this, we get the following error: https://github.com/pytorch/pytorch/issues/145100
+# NOTE: This implementation is only necessary for export with dynamo=False (dynamo=True works correctly),
+# and can be removed once Optimum switches to dynamo-based exports.
+def onnx_compatible_repeat_interleave(input_tensor, repeats, dim=None):
+    """
+    Custom implementation of torch.repeat_interleave without using torch.repeat_interleave.
+
+    Args:
+        input_tensor (torch.Tensor): The input tensor.
+        repeats (int or torch.Tensor): The number of repetitions for each element.
+        dim (int, optional): The dimension along which to repeat. Defaults to None.
+
+    Returns:
+        torch.Tensor: The repeated tensor.
+    """
+    if isinstance(repeats, int) or (torch.is_tensor(repeats) and repeats.dim() == 0):
+        if dim is None:
+            return input_tensor.flatten().unsqueeze(1).expand(-1, repeats).flatten()
+        repeats = torch.full((input_tensor.shape[dim],), repeats, dtype=torch.long, device=input_tensor.device)
+
+    if dim is None:
+        return onnx_compatible_repeat_interleave(input_tensor.flatten(), repeats, 0)
+
+    if dim != 0:
+        input_tensor = input_tensor.transpose(0, dim)
+
+    # Create expand mask
+    max_repeats = repeats.max()
+    expanded = input_tensor.unsqueeze(1).expand(-1, max_repeats, *input_tensor.shape[1:])
+    mask = torch.arange(max_repeats, device=input_tensor.device) < repeats.unsqueeze(1)
+    result = expanded[mask]
+
+    if dim != 0:
+        result = result.transpose(0, dim)
+
+    return result
+
+
+UNSUPPORTED_OPS_PATCHING_SPEC = [
+    PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold),
+    PatchingSpec(torch.Tensor, "repeat_interleave", onnx_compatible_repeat_interleave, torch.Tensor.repeat_interleave),
+    # TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
+    PatchingSpec(torch.Tensor, "__len__", lambda x: x.shape[0], torch.Tensor.__len__),
+]
 CACHE_PATCHING_SPEC = [PatchingSpec(transformers.cache_utils, "Cache", TraceableCache, transformers.cache_utils.Cache)]
 
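
As a quick sanity check (illustrative, not part of the commit), the patched implementation can be compared against torch.repeat_interleave in eager mode:

import torch

x = torch.tensor([[1, 2], [3, 4]])

# Scalar repeats along dim 0: each row repeated twice.
assert torch.equal(onnx_compatible_repeat_interleave(x, 2, dim=0), x.repeat_interleave(2, dim=0))
# Per-element repeats: row 0 once, row 1 three times.
reps = torch.tensor([1, 3])
assert torch.equal(onnx_compatible_repeat_interleave(x, reps, dim=0), x.repeat_interleave(reps, dim=0))
# dim=None flattens first, matching torch's behavior.
assert torch.equal(onnx_compatible_repeat_interleave(x, 2), x.repeat_interleave(2))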
@@ -239,7 +283,7 @@ def patched_forward(*args, **kwargs):
             # contains the output names of the model. In the case of Timm classification models, the output
             # is of type tensor. By default, it is assumed that the output names mentioned in the ONNX config
             # match the outputs in order.
-            filterd_outputs = {}
+            filtered_outputs = {}
             if isinstance(outputs, dict):
                 for name, value in outputs.items():
                     onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
@@ -248,10 +292,10 @@ def patched_forward(*args, **kwargs):
                         or (allow_past_in_outputs and name.startswith("past_key_values"))
                         or any(key.startswith(onnx_output_name) for key in config.outputs.keys())
                     ):
-                        filterd_outputs[name] = value
+                        filtered_outputs[name] = value
             elif isinstance(outputs, (list, tuple)):
                 outputs_list = list(config.outputs.keys())
-                filterd_outputs = dict(zip(outputs_list, outputs))
+                filtered_outputs = dict(zip(outputs_list, outputs))
             else:
                 if len(config.outputs) > 1:
                     num_outputs = len(config.outputs)
@@ -261,15 +305,15 @@ def patched_forward(*args, **kwargs):
                     )
                 else:
                     name = list(config.outputs.keys())[0]
-                    filterd_outputs[name] = outputs
+                    filtered_outputs[name] = outputs
                 name = list(config.outputs.keys())[0]
-                filterd_outputs[name] = outputs
+                filtered_outputs[name] = outputs
 
             if is_transformers_version(">=", "4.48"):
-                if isinstance(filterd_outputs.get("past_key_values"), (DynamicCache, EncoderDecoderCache)):
-                    filterd_outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache()
+                if isinstance(filtered_outputs.get("past_key_values"), (DynamicCache, EncoderDecoderCache)):
+                    filtered_outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache()
 
-            return filterd_outputs
+            return filtered_outputs
 
         self.patched_forward = patched_forward
 
@@ -325,15 +369,18 @@ def __init__(
         if model.config.model_type == "pix2struct" and allow_past_in_outputs:
             model.config.text_config.use_cache = True
 
-        @functools.wraps(self.orig_forward)
+        # Re-use the patched forward method from the parent class
+        self.super_patched_forward = self.patched_forward
+
+        @functools.wraps(self.super_patched_forward)
         def patched_forward(*args, **kwargs):
-            signature = inspect.signature(self.orig_forward)
+            signature = inspect.signature(self.super_patched_forward)
             args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)
 
-            outputs = self.orig_forward(*args, **kwargs)
+            outputs = self.super_patched_forward(*args, **kwargs)
 
             # Filter out cross attention past key values output from the decoder using KV cache, as they are constants.
-            filterd_outputs = {}
+            filtered_outputs = {}
             for name, value in outputs.items():
                 onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
                 if (
@@ -346,17 +393,17 @@ def patched_forward(*args, **kwargs):
                         # Who cares about the encoder outputs in the decoder?
                         continue
                     else:
-                        filterd_outputs[name] = value
+                        filtered_outputs[name] = value
                 else:
                     if self.real_config._behavior == "monolith" or (
                         self.real_config._behavior == "decoder"
                         and (self.real_config.is_merged or not self.real_config.use_past_in_inputs)
                     ):
-                        filterd_outputs[name] = value
+                        filtered_outputs[name] = value
                     elif self.real_config._behavior == "decoder" and self.real_config.use_past_in_inputs:
                         # The filtering happens here. The decoder with use_past_in_inputs=True corresponds to the autoregressive one.
-                        filterd_outputs[name] = tuple([v[:2] for v in value])
-            return filterd_outputs
+                        filtered_outputs[name] = tuple([v[:2] for v in value])
+            return filtered_outputs
 
         self.patched_forward = patched_forward
 
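
The change above layers the seq2seq patcher on top of the parent's already-patched forward instead of wrapping orig_forward twice. A minimal sketch of the pattern with hypothetical mini-classes (not Optimum's actual patchers):

class BasePatcher:
    def __init__(self, forward):
        self.orig_forward = forward
        # Parent-level patch: normalize raw outputs into a dict.
        self.patched_forward = lambda *args, **kwargs: {"logits": self.orig_forward(*args, **kwargs)}


class Seq2SeqPatcher(BasePatcher):
    def __init__(self, forward):
        super().__init__(forward)
        # Capture the parent's wrapper before installing our own on top of it.
        self.super_patched_forward = self.patched_forward

        def patched_forward(*args, **kwargs):
            outputs = self.super_patched_forward(*args, **kwargs)
            # Child-level patch: drop outputs the exported decoder should not expose.
            return {k: v for k, v in outputs.items() if not k.startswith("cross_attn")}

        self.patched_forward = patched_forward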
optimum/exporters/tasks.py

+7 −0
@@ -903,6 +903,13 @@ class TasksManager:
             "token-classification",
             onnx="ModernBertOnnxConfig",
         ),
+        "moonshine": supported_tasks_mapping(
+            "feature-extraction",
+            "feature-extraction-with-past",
+            "automatic-speech-recognition",
+            "automatic-speech-recognition-with-past",
+            onnx="MoonshineOnnxConfig",
+        ),
         "mpnet": supported_tasks_mapping(
             "feature-extraction",
             "fill-mask",

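With the mapping above in place, the task registry resolves the new ONNX config by model type. A hedged sketch of the lookup (the exact keyword arguments are assumptions about TasksManager's helper):

from optimum.exporters.tasks import TasksManager

# Sketch: fetch the constructor registered for moonshine above.
constructor = TasksManager.get_exporter_config_constructor(
    exporter="onnx",
    task="automatic-speech-recognition-with-past",
    model_type="moonshine",
    library_name="transformers",
)
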
tests/exporters/exporters_utils.py

+2 −0
@@ -130,6 +130,7 @@
     "mobilenet-v1": "hf-internal-testing/tiny-random-MobileNetV1Model",
     "mobilevit": "hf-internal-testing/tiny-random-mobilevit",
     "modernbert": "hf-internal-testing/tiny-random-ModernBertForMaskedLM",
+    "moonshine": "hf-internal-testing/tiny-random-MoonshineForConditionalGeneration",
     "mpnet": "hf-internal-testing/tiny-random-MPNetModel",
     "mpt": "hf-internal-testing/tiny-random-MptForCausalLM",
     "mt5": "lewtun/tiny-random-mt5",
@@ -271,6 +272,7 @@
     "mobilenet_v2": "google/mobilenet_v2_0.35_96",
     "mobilevit": "apple/mobilevit-small",
     "modernbert": "answerdotai/ModernBERT-base",
+    "moonshine": "UsefulSensors/moonshine-tiny",
     "mpt": "mosaicml/mpt-7b",
     "mt5": "google/mt5-small",
     "musicgen": "facebook/musicgen-small",
