|
23 | 23 | from transformers.generation import GenerationMixin
|
24 | 24 | from transformers.utils import is_tf_available, is_torch_available
|
25 | 25 |
|
| 26 | +from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder |
26 | 27 | from openvino.runtime import Model, save_model
|
27 | 28 | from openvino.runtime.exceptions import OVTypeError
|
28 | 29 | from openvino.tools.ovc import convert_model
|
|
46 | 47 | is_openvino_tokenizers_version,
|
47 | 48 | is_tokenizers_version,
|
48 | 49 | is_transformers_version,
|
| 50 | + is_openvino_version, |
49 | 51 | )
|
50 | 52 | from optimum.utils import DEFAULT_DUMMY_SHAPES, is_diffusers_available
|
51 | 53 |
|
@@ -427,15 +429,20 @@ def ts_patched_forward(*args, **kwargs):
|
427 | 429 |
|
428 | 430 | patcher.patched_forward = ts_patched_forward
|
429 | 431 |
|
| 432 | + decoder_kwargs = {} |
| 433 | + if library_name == "diffusers" and is_openvino_version(">=", "2025.0"): |
| 434 | + decoder_kwargs["trace_kwargs"] = {"check_trace": False} |
| 435 | + |
430 | 436 | with patcher:
|
431 | 437 | if patch_16bit_model:
|
432 | 438 | from openvino.frontend.pytorch.patch_model import __make_16bit_traceable
|
433 | 439 |
|
434 | 440 | __make_16bit_traceable(model)
|
435 | 441 | check_dummy_inputs_are_allowed(model, dummy_inputs)
|
436 | 442 | input_info = _get_input_info(model, config, dummy_inputs)
|
| 443 | + decoder = TorchScriptPythonDecoder(model, example_input=dummy_inputs, **decoder_kwargs) |
437 | 444 | ov_model = convert_model(
|
438 |
| - model, |
| 445 | + decoder, |
439 | 446 | example_input=dummy_inputs,
|
440 | 447 | input=[(item.shape, item.type) for item in input_info],
|
441 | 448 | )
|
|
0 commit comments