Skip to content

Commit f6d8365

Browse files
committed
try to remove onnx fallback
1 parent fe10aaa commit f6d8365

File tree

1 file changed

+34
-71
lines changed

1 file changed

+34
-71
lines changed

optimum/exporters/openvino/convert.py

+34-71
Original file line numberDiff line numberDiff line change
@@ -415,80 +415,43 @@ def export_pytorch(
415415
dummy_inputs = config.rename_ambiguous_inputs(dummy_inputs)
416416
dummy_inputs, dict_inputs = remove_none_from_dummy_inputs(dummy_inputs)
417417

418-
try:
419-
# TorchScript used behind OpenVINO conversion. Optimum supports only return_dict=True models for patching,
420-
# while TorchScript do not support dictionary with values of mixed types (e.g. Tensor and None) in model input/output
421-
# To handle it, additional wrapper on patcher forward applied.
422-
# model.config.torchscript = True can not be used for patching, because it overrides return_dict to False
423-
patcher = config.patch_model_for_export(model, model_kwargs=model_kwargs)
424-
patched_forward = patcher.patched_forward
425-
426-
@functools.wraps(patched_forward)
427-
def ts_patched_forward(*args, **kwargs):
428-
for i in range(len(dict_inputs)):
429-
input_name, keys = dict_inputs[i]
430-
tuple_input = kwargs[input_name]
431-
input_dict = dict(zip(keys, tuple_input))
432-
kwargs[input_name] = input_dict
433-
outputs = patched_forward(*args, **kwargs)
434-
return tuple([value if not isinstance(value, list) else tuple(value) for value in outputs.values()])
435-
436-
patcher.patched_forward = ts_patched_forward
437-
438-
ts_decoder_kwargs = {}
439-
if library_name == "diffusers" and is_openvino_version(">=", "2025.0"):
440-
ts_decoder_kwargs["trace_kwargs"] = {"check_trace": False}
441-
442-
with patcher:
443-
if patch_16bit_model:
444-
from openvino.frontend.pytorch.patch_model import __make_16bit_traceable
445-
446-
__make_16bit_traceable(model)
447-
check_dummy_inputs_are_allowed(model, dummy_inputs)
448-
input_info = _get_input_info(model, config, dummy_inputs)
449-
ts_decoder = TorchScriptPythonDecoder(model, example_input=dummy_inputs, **ts_decoder_kwargs)
450-
ov_model = convert_model(
451-
ts_decoder,
452-
example_input=dummy_inputs,
453-
input=[(item.shape, item.type) for item in input_info],
454-
)
455-
456-
except Exception as ex:
457-
logger.warning(f"Export model to OpenVINO directly failed with: \n{ex}.\nModel will be exported to ONNX")
458-
459-
if stateful:
460-
# cannot raise because stateful is enabled by default and it would break backward compatibility for models that couldn't convert to OV directly
461-
# TODO: Implement stateful for ONNX path as well, not doing it right now because of lack of validation
462-
logger.warning(
463-
"[ WARNING ] Making stateful models is not supported when exporting to ONNX as an intermediate step. "
464-
"A stateless model will be exported instead. It may result in sub-optimal inference performance."
465-
"Provide a model that can be converted to OpenVINO without fallback to ONNX conversion path."
466-
)
467-
418+
# TorchScript is used behind the OpenVINO conversion. Optimum supports only return_dict=True models for patching,
419+
# while TorchScript does not support dictionaries with values of mixed types (e.g. Tensor and None) in model input/output.
420+
# To handle this, an additional wrapper is applied to the patcher's forward.
421+
# model.config.torchscript = True cannot be used for patching, because it overrides return_dict to False
422+
patcher = config.patch_model_for_export(model, model_kwargs=model_kwargs)
423+
patched_forward = patcher.patched_forward
424+
425+
@functools.wraps(patched_forward)
426+
def ts_patched_forward(*args, **kwargs):
427+
for i in range(len(dict_inputs)):
428+
input_name, keys = dict_inputs[i]
429+
tuple_input = kwargs[input_name]
430+
input_dict = dict(zip(keys, tuple_input))
431+
kwargs[input_name] = input_dict
432+
outputs = patched_forward(*args, **kwargs)
433+
return tuple([value if not isinstance(value, list) else tuple(value) for value in outputs.values()])
434+
435+
patcher.patched_forward = ts_patched_forward
436+
437+
ts_decoder_kwargs = {}
438+
if library_name == "diffusers" and is_openvino_version(">=", "2025.0"):
439+
ts_decoder_kwargs["trace_kwargs"] = {"check_trace": False}
440+
441+
with patcher:
468442
if patch_16bit_model:
469-
from openvino.frontend.pytorch.patch_model import unpatch_model
470-
471-
unpatch_model(model, "_openvino_module_extension_patch_orig_forward")
472-
for m in model.modules():
473-
if any(p.dtype in [torch.float16, torch.bfloat16] for p in m.parameters(False)) or any(
474-
b.dtype in [torch.float16, torch.bfloat16] for b in m.buffers(False)
475-
):
476-
m.float()
477-
478-
return export_pytorch_via_onnx(
479-
model,
480-
config,
481-
opset,
482-
output,
483-
device,
484-
input_shapes,
485-
model_kwargs,
486-
ov_config=ov_config,
487-
library_name=library_name,
443+
from openvino.frontend.pytorch.patch_model import __make_16bit_traceable
444+
445+
__make_16bit_traceable(model)
446+
check_dummy_inputs_are_allowed(model, dummy_inputs)
447+
input_info = _get_input_info(model, config, dummy_inputs)
448+
ts_decoder = TorchScriptPythonDecoder(model, example_input=dummy_inputs, **ts_decoder_kwargs)
449+
ov_model = convert_model(
450+
ts_decoder,
451+
example_input=dummy_inputs,
452+
input=[(item.shape, item.type) for item in input_info],
488453
)
489454

490-
ov_model.validate_nodes_and_infer_types() # TODO: remove as unnecessary validation?
491-
492455
output_names = list(config.outputs.keys())
493456
for idx, out_tensor in enumerate(ov_model.outputs):
494457
if idx < len(output_names):

0 commit comments

Comments
 (0)