Commit fb1910e

fix export

1 parent fa97edd commit fb1910e

1 file changed (+25 -3 lines): optimum/exporters/openvino/convert.py

@@ -14,6 +14,7 @@
 
 import functools
 import gc
+import inspect
 import logging
 import os
 from pathlib import Path
@@ -43,6 +44,7 @@
     clear_class_registry,
     flattenize_inputs,
     get_input_shapes,
+    remove_none_from_dummy_inputs,
 )
 
 
@@ -370,9 +372,29 @@ def export_pytorch(
             )
 
         dummy_inputs = config.rename_ambiguous_inputs(dummy_inputs)
+        dummy_inputs, dict_inputs = remove_none_from_dummy_inputs(dummy_inputs)
 
         try:
-            with config.patch_model_for_export(model, model_kwargs=model_kwargs):
+            # TorchScript is used behind the OpenVINO conversion. Optimum supports patching only return_dict=True models,
+            # while TorchScript does not support dictionaries with values of mixed types (e.g. Tensor and None) in model inputs/outputs.
+            # To handle this, an additional wrapper is applied on top of the patcher forward.
+            # model.config.torchscript = True cannot be used for patching, because it overrides return_dict to False.
+            patcher = config.patch_model_for_export(model, model_kwargs=model_kwargs)
+            patched_forward = patcher.patched_forward
+
+            @functools.wraps(patched_forward)
+            def ts_patched_forward(*args, **kwargs):
+                for i in range(len(dict_inputs)):
+                    input_name, keys = dict_inputs[i]
+                    tuple_input = kwargs[input_name]
+                    input_dict = dict(zip(keys, tuple_input))
+                    kwargs[input_name] = input_dict
+                outputs = patched_forward(*args, **kwargs)
+                return tuple(outputs.values())
+
+            patcher.patched_forward = ts_patched_forward
+
+            with patcher:
                 check_dummy_inputs_are_allowed(model, dummy_inputs)
                 inputs = config.ordered_inputs(model)
                 input_names = list(inputs.keys())
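
The helper remove_none_from_dummy_inputs is imported from the exporter utilities and its body is not part of this diff. Below is a minimal sketch, under assumptions, of the contract that ts_patched_forward relies on: the helper drops None values and flattens dict-valued dummy inputs into tuples, returning the cleaned inputs together with a list of (input name, original keys). The function name and implementation here are illustrative only, not the code in optimum/exporters/openvino/utils.py.

# Illustrative sketch only: mirrors the assumed return contract of
# remove_none_from_dummy_inputs, not its actual implementation.
from typing import Any, Dict, List, Tuple


def remove_none_from_dummy_inputs_sketch(
    dummy_inputs: Dict[str, Any],
) -> Tuple[Dict[str, Any], List[Tuple[str, List[str]]]]:
    cleaned: Dict[str, Any] = {}
    dict_inputs: List[Tuple[str, List[str]]] = []
    for name, value in dummy_inputs.items():
        if value is None:
            # TorchScript tracing cannot handle None inputs, so drop them.
            continue
        if isinstance(value, dict):
            # Flatten dict inputs into tuples and remember their keys, so the
            # wrapped forward can rebuild {key: tensor} before calling the model.
            dict_inputs.append((name, list(value.keys())))
            cleaned[name] = tuple(value.values())
        else:
            cleaned[name] = value
    return cleaned, dict_inputs

With that contract, ts_patched_forward simply reverses the flattening per recorded input (dict(zip(keys, tuple_input))) and returns the model outputs as a plain tuple, so the traced module sees no dicts or None values on either side.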
@@ -404,7 +426,8 @@ def export_pytorch(
                 compression_ratio=compression_ratio,
             )
 
-        ordered_dummy_inputs = {param: dummy_inputs[param] for param in inputs if param in dummy_inputs}
+        sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
+        ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs}
         ordered_input_names = list(inputs)
         flatten_inputs = flattenize_inputs(ordered_dummy_inputs.values())
         ov_model.validate_nodes_and_infer_types()
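
The second change orders the dummy inputs by the model's forward signature instead of by the config's input mapping, so they line up with the positional order the model actually declares. A small standalone illustration of why inspect.signature(...).parameters gives that order; toy_forward and the placeholder values below stand in for a real model and are not from the repository:

# Toy illustration of the signature-based ordering used above.
import inspect


def toy_forward(input_ids=None, attention_mask=None, token_type_ids=None):
    return None


dummy_inputs = {"attention_mask": "mask", "input_ids": "ids"}  # arbitrary dict order

sig = inspect.signature(toy_forward)
ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs}

print(list(ordered_dummy_inputs))  # ['input_ids', 'attention_mask']: follows the signature, not the dict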
@@ -418,7 +441,6 @@ def export_pytorch(
             inp_data = flatten_inputs[idx]
             static_shape = PartialShape(inp_data.shape)
             dims = inputs[input_name]
-
             for dim in dims:
                 static_shape[dim] = -1
             inp_tensor.get_node().set_partial_shape(static_shape)
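
For context, the surrounding loop relaxes the static shape captured from each dummy tensor so the listed axes stay dynamic on the OpenVINO model input. A minimal standalone sketch of that pattern, with made-up axis indices (in convert.py they come from inputs[input_name]) and assuming PartialShape is imported from openvino.runtime as elsewhere in this file:

# Standalone sketch of the dynamic-shape marking pattern used in the loop above.
from openvino.runtime import PartialShape

static_shape = PartialShape([1, 128])  # shape captured from a dummy input tensor
for dim in (0, 1):  # axes assumed dynamic for the example (batch, sequence)
    static_shape[dim] = -1  # -1 turns the fixed dimension into a dynamic one

print(static_shape)  # the shape is now fully dynamic, e.g. [?,?]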
