@@ -156,7 +156,51 @@ def onnx_compatible_unfold(input_tensor, dimension, size, step):
    return result


-UNSUPPORTED_OPS_PATCHING_SPEC = [PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold)]
+# An ONNX-export-compatible version of `tensor.repeat_interleave`.
+# Without this, we get the following error: https://github.com/pytorch/pytorch/issues/145100
+# NOTE: This implementation is only necessary for export with dynamo=False (dynamo=True works correctly)
+# and can be removed once Optimum switches to dynamo-based exports.
+def onnx_compatible_repeat_interleave(input_tensor, repeats, dim=None):
+    """
+    Custom implementation of torch.repeat_interleave without using torch.repeat_interleave.
+
+    Args:
+        input_tensor (torch.Tensor): The input tensor.
+        repeats (int or torch.Tensor): The number of repetitions for each element.
+        dim (int, optional): The dimension along which to repeat. Defaults to None.
+
+    Returns:
+        torch.Tensor: The repeated tensor.
+    """
+    if isinstance(repeats, int) or (torch.is_tensor(repeats) and repeats.dim() == 0):
+        if dim is None:
+            return input_tensor.flatten().unsqueeze(1).expand(-1, repeats).flatten()
+        repeats = torch.full((input_tensor.shape[dim],), repeats, dtype=torch.long, device=input_tensor.device)
+
+    if dim is None:
+        return onnx_compatible_repeat_interleave(input_tensor.flatten(), repeats, 0)
+
+    if dim != 0:
+        input_tensor = input_tensor.transpose(0, dim)
+
+    # Create expand mask
+    max_repeats = repeats.max()
+    expanded = input_tensor.unsqueeze(1).expand(-1, max_repeats, *input_tensor.shape[1:])
+    mask = torch.arange(max_repeats, device=input_tensor.device) < repeats.unsqueeze(1)
+    result = expanded[mask]
+
+    if dim != 0:
+        result = result.transpose(0, dim)
+
+    return result
+
+
+UNSUPPORTED_OPS_PATCHING_SPEC = [
+    PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold),
+    PatchingSpec(torch.Tensor, "repeat_interleave", onnx_compatible_repeat_interleave, torch.Tensor.repeat_interleave),
+    # TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
+    PatchingSpec(torch.Tensor, "__len__", lambda x: x.shape[0], torch.Tensor.__len__),
+]
CACHE_PATCHING_SPEC = [PatchingSpec(transformers.cache_utils, "Cache", TraceableCache, transformers.cache_utils.Cache)]

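For reference, a quick eager-mode sanity check (not part of the patch; it assumes `onnx_compatible_repeat_interleave` from the hunk above is defined or imported in the current scope) that the fallback matches `torch.repeat_interleave` for both scalar and per-element repeats:

```python
import torch

# Assumes `onnx_compatible_repeat_interleave` from the patch above is in scope.
x = torch.arange(6).reshape(2, 3)

# Scalar repeats, flattened output (dim=None).
assert torch.equal(
    onnx_compatible_repeat_interleave(x, 2),
    torch.repeat_interleave(x, 2),
)

# Per-element repeats along a non-zero dimension.
r = torch.tensor([1, 2, 1])
assert torch.equal(
    onnx_compatible_repeat_interleave(x, r, dim=1),
    torch.repeat_interleave(x, r, dim=1),
)
```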
@@ -239,7 +283,7 @@ def patched_forward(*args, **kwargs):
            # contains the output names of the model. In the case of Timm classification models, the output
            # is of type tensor. By default, it is assumed that the output names mentioned in the ONNX config
            # match the outputs in order.
-            filterd_outputs = {}
+            filtered_outputs = {}
            if isinstance(outputs, dict):
                for name, value in outputs.items():
                    onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
@@ -248,10 +292,10 @@ def patched_forward(*args, **kwargs):
                        or (allow_past_in_outputs and name.startswith("past_key_values"))
                        or any(key.startswith(onnx_output_name) for key in config.outputs.keys())
                    ):
-                        filterd_outputs[name] = value
+                        filtered_outputs[name] = value
            elif isinstance(outputs, (list, tuple)):
                outputs_list = list(config.outputs.keys())
-                filterd_outputs = dict(zip(outputs_list, outputs))
+                filtered_outputs = dict(zip(outputs_list, outputs))
            else:
                if len(config.outputs) > 1:
                    num_outputs = len(config.outputs)
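As the comment in the hunk above notes, tuple outputs are matched to the ONNX config's output names positionally. A toy illustration (hypothetical names and values, not Optimum's API):

```python
# When the model returns a plain tuple, outputs are paired with the ONNX
# config's output names in order.
config_output_names = ["logits", "pooler_output"]   # e.g. config.outputs.keys()
model_outputs = ("logits_tensor", "pooled_tensor")  # e.g. a tuple returned by forward
filtered_outputs = dict(zip(config_output_names, model_outputs))
# {'logits': 'logits_tensor', 'pooler_output': 'pooled_tensor'}
```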
@@ -261,15 +305,15 @@ def patched_forward(*args, **kwargs):
                    )
                else:
                    name = list(config.outputs.keys())[0]
-                    filterd_outputs[name] = outputs
+                    filtered_outputs[name] = outputs
                name = list(config.outputs.keys())[0]
-                filterd_outputs[name] = outputs
+                filtered_outputs[name] = outputs

            if is_transformers_version(">=", "4.48"):
-                if isinstance(filterd_outputs.get("past_key_values"), (DynamicCache, EncoderDecoderCache)):
-                    filterd_outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache()
+                if isinstance(filtered_outputs.get("past_key_values"), (DynamicCache, EncoderDecoderCache)):
+                    filtered_outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache()

-            return filterd_outputs
+            return filtered_outputs


        self.patched_forward = patched_forward
@@ -325,15 +369,18 @@ def __init__(
        if model.config.model_type == "pix2struct" and allow_past_in_outputs:
            model.config.text_config.use_cache = True

-        @functools.wraps(self.orig_forward)
+        # Re-use the patched forward method from the parent class
+        self.super_patched_forward = self.patched_forward
+
+        @functools.wraps(self.super_patched_forward)
        def patched_forward(*args, **kwargs):
-            signature = inspect.signature(self.orig_forward)
+            signature = inspect.signature(self.super_patched_forward)
            args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)

-            outputs = self.orig_forward(*args, **kwargs)
+            outputs = self.super_patched_forward(*args, **kwargs)

            # Filter out cross attention past key values output from the decoder using KV cache, as they are constants.
-            filterd_outputs = {}
+            filtered_outputs = {}
            for name, value in outputs.items():
                onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
                if (
@@ -346,17 +393,17 @@ def patched_forward(*args, **kwargs):
                            # Who cares about the encoder outputs in the decoder?
                            continue
                        else:
-                            filterd_outputs[name] = value
+                            filtered_outputs[name] = value
                    else:
                        if self.real_config._behavior == "monolith" or (
                            self.real_config._behavior == "decoder"
                            and (self.real_config.is_merged or not self.real_config.use_past_in_inputs)
                        ):
-                            filterd_outputs[name] = value
+                            filtered_outputs[name] = value
                        elif self.real_config._behavior == "decoder" and self.real_config.use_past_in_inputs:
                            # The filtering happens here. The decoder with use_past_in_inputs=True corresponds to the autoregressive one.
-                            filterd_outputs[name] = tuple([v[:2] for v in value])
-            return filterd_outputs
+                            filtered_outputs[name] = tuple([v[:2] for v in value])
+            return filtered_outputs


        self.patched_forward = patched_forward
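The `self.super_patched_forward = self.patched_forward` change above is a wrap-the-parent's-wrapper pattern: the subclass captures the base patcher's already-patched forward and wraps it again, so both layers of output filtering run. A minimal, hypothetical sketch of the idea (class names and the filtering performed are illustrative only, not Optimum's actual patcher API):

```python
import functools

class BasePatcher:
    def __init__(self, forward):
        self.orig_forward = forward

        @functools.wraps(self.orig_forward)
        def patched_forward(*args, **kwargs):
            outputs = self.orig_forward(*args, **kwargs)
            # Base-level post-processing of the model outputs.
            return {"logits": outputs}

        self.patched_forward = patched_forward


class Seq2SeqPatcher(BasePatcher):
    def __init__(self, forward):
        super().__init__(forward)
        # Re-use the parent's wrapper instead of the raw forward,
        # so its post-processing still runs.
        self.super_patched_forward = self.patched_forward

        @functools.wraps(self.super_patched_forward)
        def patched_forward(*args, **kwargs):
            outputs = self.super_patched_forward(*args, **kwargs)
            # Additional, subclass-specific filtering on top of the parent's.
            outputs["cross_attention_kv_dropped"] = True
            return outputs

        self.patched_forward = patched_forward


# Usage: both wrappers are applied, innermost (base) first.
patcher = Seq2SeqPatcher(lambda x: x + 1)
print(patcher.patched_forward(1))  # {'logits': 2, 'cross_attention_kv_dropped': True}
```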