@@ -415,80 +415,43 @@ def export_pytorch(
     dummy_inputs = config.rename_ambiguous_inputs(dummy_inputs)
     dummy_inputs, dict_inputs = remove_none_from_dummy_inputs(dummy_inputs)
 
-    try:
-        # TorchScript used behind OpenVINO conversion. Optimum supports only return_dict=True models for patching,
-        # while TorchScript do not support dictionary with values of mixed types (e.g. Tensor and None) in model input/output
-        # To handle it, additional wrapper on patcher forward applied.
-        # model.config.torchscript = True can not be used for patching, because it overrides return_dict to False
-        patcher = config.patch_model_for_export(model, model_kwargs=model_kwargs)
-        patched_forward = patcher.patched_forward
-
-        @functools.wraps(patched_forward)
-        def ts_patched_forward(*args, **kwargs):
-            for i in range(len(dict_inputs)):
-                input_name, keys = dict_inputs[i]
-                tuple_input = kwargs[input_name]
-                input_dict = dict(zip(keys, tuple_input))
-                kwargs[input_name] = input_dict
-            outputs = patched_forward(*args, **kwargs)
-            return tuple([value if not isinstance(value, list) else tuple(value) for value in outputs.values()])
-
-        patcher.patched_forward = ts_patched_forward
-
-        ts_decoder_kwargs = {}
-        if library_name == "diffusers" and is_openvino_version(">=", "2025.0"):
-            ts_decoder_kwargs["trace_kwargs"] = {"check_trace": False}
-
-        with patcher:
-            if patch_16bit_model:
-                from openvino.frontend.pytorch.patch_model import __make_16bit_traceable
-
-                __make_16bit_traceable(model)
-            check_dummy_inputs_are_allowed(model, dummy_inputs)
-            input_info = _get_input_info(model, config, dummy_inputs)
-            ts_decoder = TorchScriptPythonDecoder(model, example_input=dummy_inputs, **ts_decoder_kwargs)
-            ov_model = convert_model(
-                ts_decoder,
-                example_input=dummy_inputs,
-                input=[(item.shape, item.type) for item in input_info],
-            )
-
-    except Exception as ex:
-        logger.warning(f"Export model to OpenVINO directly failed with: \n{ex}.\nModel will be exported to ONNX")
-
-        if stateful:
-            # cannot raise because stateful is enabled by default and it would break backward compatibility for models that couldn't convert to OV directly
-            # TODO: Implement stateful for ONNX path as well, not doing it right now because of lack of validation
-            logger.warning(
-                "[ WARNING ] Making stateful models is not supported when exporting to ONNX as an intermediate step. "
-                "A stateless model will be exported instead. It may result in sub-optimal inference performance."
-                "Provide a model that can be converted to OpenVINO without fallback to ONNX conversion path."
-            )
-
+    # TorchScript used behind OpenVINO conversion. Optimum supports only return_dict=True models for patching,
+    # while TorchScript do not support dictionary with values of mixed types (e.g. Tensor and None) in model input/output
+    # To handle it, additional wrapper on patcher forward applied.
+    # model.config.torchscript = True can not be used for patching, because it overrides return_dict to False
+    patcher = config.patch_model_for_export(model, model_kwargs=model_kwargs)
+    patched_forward = patcher.patched_forward
+
+    @functools.wraps(patched_forward)
+    def ts_patched_forward(*args, **kwargs):
+        for i in range(len(dict_inputs)):
+            input_name, keys = dict_inputs[i]
+            tuple_input = kwargs[input_name]
+            input_dict = dict(zip(keys, tuple_input))
+            kwargs[input_name] = input_dict
+        outputs = patched_forward(*args, **kwargs)
+        return tuple([value if not isinstance(value, list) else tuple(value) for value in outputs.values()])
+
+    patcher.patched_forward = ts_patched_forward
+
+    ts_decoder_kwargs = {}
+    if library_name == "diffusers" and is_openvino_version(">=", "2025.0"):
+        ts_decoder_kwargs["trace_kwargs"] = {"check_trace": False}
+
+    with patcher:
         if patch_16bit_model:
-            from openvino.frontend.pytorch.patch_model import unpatch_model
-
-            unpatch_model(model, "_openvino_module_extension_patch_orig_forward")
-            for m in model.modules():
-                if any(p.dtype in [torch.float16, torch.bfloat16] for p in m.parameters(False)) or any(
-                    b.dtype in [torch.float16, torch.bfloat16] for b in m.buffers(False)
-                ):
-                    m.float()
-
-        return export_pytorch_via_onnx(
-            model,
-            config,
-            opset,
-            output,
-            device,
-            input_shapes,
-            model_kwargs,
-            ov_config=ov_config,
-            library_name=library_name,
+            from openvino.frontend.pytorch.patch_model import __make_16bit_traceable
+
+            __make_16bit_traceable(model)
+        check_dummy_inputs_are_allowed(model, dummy_inputs)
+        input_info = _get_input_info(model, config, dummy_inputs)
+        ts_decoder = TorchScriptPythonDecoder(model, example_input=dummy_inputs, **ts_decoder_kwargs)
+        ov_model = convert_model(
+            ts_decoder,
+            example_input=dummy_inputs,
+            input=[(item.shape, item.type) for item in input_info],
         )
 
-    ov_model.validate_nodes_and_infer_types()  # TODO: remove as unnecessary validation?
-
     output_names = list(config.outputs.keys())
     for idx, out_tensor in enumerate(ov_model.outputs):
         if idx < len(output_names):
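
Note: the comment block kept in the new code explains that TorchScript cannot trace dict-valued inputs with mixed value types, so dict inputs are flattened to tuples for tracing and the forward wrapper rebuilds them. A rough, self-contained sketch of that round trip is below; the "cache"/"key"/"value" names and shapes are hypothetical, not taken from this PR.

import torch

# (input name, original dict keys) as produced by remove_none_from_dummy_inputs
dict_inputs = [("cache", ["key", "value"])]
# what the tracer actually sees: a flat tuple instead of a dict
flat_kwargs = {"cache": (torch.zeros(1, 2), torch.zeros(1, 2))}

def rebuild_dict_inputs(kwargs):
    # mirrors the repacking step ts_patched_forward performs before calling the patched forward
    for input_name, keys in dict_inputs:
        kwargs[input_name] = dict(zip(keys, kwargs[input_name]))
    return kwargs

rebuilt = rebuild_dict_inputs(dict(flat_kwargs))
# rebuilt["cache"] is now {"key": tensor, "value": tensor} again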
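For context on the path this change makes mandatory (the ONNX fallback is removed), here is a minimal sketch of converting a plain PyTorch module directly to OpenVINO with openvino.convert_model; TinyModel and the input shape are illustrative only and not part of this PR.

import torch
import openvino as ov

class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)

model = TinyModel().eval()

# convert_model traces the module with the example input and returns an ov.Model,
# with no intermediate ONNX file on disk
ov_model = ov.convert_model(model, example_input=torch.zeros(1, 4))
ov.save_model(ov_model, "tiny_model.xml")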