14
14
15
15
import functools
16
16
import gc
17
+ import inspect
17
18
import logging
18
19
import os
19
20
from pathlib import Path
43
44
clear_class_registry ,
44
45
flattenize_inputs ,
45
46
get_input_shapes ,
47
+ remove_none_from_dummy_inputs ,
46
48
)
47
49
48
50
@@ -370,9 +372,29 @@ def export_pytorch(
370
372
)
371
373
372
374
dummy_inputs = config .rename_ambiguous_inputs (dummy_inputs )
375
+ dummy_inputs , dict_inputs = remove_none_from_dummy_inputs (dummy_inputs )
373
376
374
377
try :
375
- with config .patch_model_for_export (model , model_kwargs = model_kwargs ):
378
+ # TorchScript used behind OpenVINO conversion. Optimum supports only return_dict=True models for patching,
379
+ # while TorchScript do not support dictionary with values of mixed types (e.g. Tensor and None) in model input/output
380
+ # To handle it, additional wrapper on patcher forward applied.
381
+ # model.config.torchscript = True can not be used for patching, because it overrides return_dict to False
382
+ patcher = config .patch_model_for_export (model , model_kwargs = model_kwargs )
383
+ patched_forward = patcher .patched_forward
384
+
385
+ @functools .wraps (patched_forward )
386
+ def ts_patched_forward (* args , ** kwargs ):
387
+ for i in range (len (dict_inputs )):
388
+ input_name , keys = dict_inputs [i ]
389
+ tuple_input = kwargs [input_name ]
390
+ input_dict = dict (zip (keys , tuple_input ))
391
+ kwargs [input_name ] = input_dict
392
+ outputs = patched_forward (* args , ** kwargs )
393
+ return tuple (outputs .values ())
394
+
395
+ patcher .patched_forward = ts_patched_forward
396
+
397
+ with patcher :
376
398
check_dummy_inputs_are_allowed (model , dummy_inputs )
377
399
inputs = config .ordered_inputs (model )
378
400
input_names = list (inputs .keys ())
@@ -404,7 +426,8 @@ def export_pytorch(
404
426
compression_ratio = compression_ratio ,
405
427
)
406
428
407
- ordered_dummy_inputs = {param : dummy_inputs [param ] for param in inputs if param in dummy_inputs }
429
+ sig = inspect .signature (model .forward ) if hasattr (model , "forward" ) else inspect .signature (model .call )
430
+ ordered_dummy_inputs = {param : dummy_inputs [param ] for param in sig .parameters if param in dummy_inputs }
408
431
ordered_input_names = list (inputs )
409
432
flatten_inputs = flattenize_inputs (ordered_dummy_inputs .values ())
410
433
ov_model .validate_nodes_and_infer_types ()
@@ -418,7 +441,6 @@ def export_pytorch(
418
441
inp_data = flatten_inputs [idx ]
419
442
static_shape = PartialShape (inp_data .shape )
420
443
dims = inputs [input_name ]
421
-
422
444
for dim in dims :
423
445
static_shape [dim ] = - 1
424
446
inp_tensor .get_node ().set_partial_shape (static_shape )
0 commit comments