20
20
from pathlib import Path
21
21
from typing import TYPE_CHECKING , Any , Callable , Dict , List , Optional , Tuple , Union
22
22
23
- import onnx
24
23
from transformers .generation import GenerationMixin
25
24
from transformers .utils import is_tf_available , is_torch_available
26
25
27
26
from openvino .runtime import Model , save_model
28
27
from openvino .runtime .exceptions import OVTypeError
29
28
from openvino .tools .ovc import convert_model
30
29
from optimum .exporters import TasksManager
31
- from optimum .exporters .onnx .base import OnnxConfig
32
- from optimum .exporters .onnx .convert import check_dummy_inputs_are_allowed
33
- from optimum .exporters .onnx .convert import export_pytorch as export_pytorch_to_onnx
34
- from optimum .exporters .onnx .convert import export_tensorflow as export_tensorflow_onnx
35
30
from optimum .exporters .utils import (
36
31
_get_submodels_and_export_configs as _default_get_submodels_and_export_configs ,
37
32
)
89
84
90
85
91
86
if TYPE_CHECKING :
87
+ from optimum .exporters .onnx .base import OnnxConfig
92
88
from optimum .intel .openvino .configuration import OVConfig
93
89
94
90
@@ -115,7 +111,7 @@ def _save_model(
115
111
path : str ,
116
112
ov_config : Optional ["OVConfig" ] = None ,
117
113
library_name : Optional [str ] = None ,
118
- config : OnnxConfig = None ,
114
+ config : " OnnxConfig" = None ,
119
115
):
120
116
compress_to_fp16 = ov_config is not None and ov_config .dtype == "fp16"
121
117
model = _add_version_info_to_model (model , library_name )
@@ -129,7 +125,7 @@ def _save_model(
129
125
130
126
def export (
131
127
model : Union ["PreTrainedModel" , "TFPreTrainedModel" , "ModelMixin" , "DiffusionPipeline" ],
132
- config : OnnxConfig ,
128
+ config : " OnnxConfig" ,
133
129
output : Path ,
134
130
opset : Optional [int ] = None ,
135
131
device : str = "cpu" ,
@@ -212,7 +208,7 @@ def export(
212
208
213
209
def export_tensorflow (
214
210
model : Union ["PreTrainedModel" , "ModelMixin" ],
215
- config : OnnxConfig ,
211
+ config : " OnnxConfig" ,
216
212
opset : int ,
217
213
output : Path ,
218
214
ov_config : Optional ["OVConfig" ] = None ,
@@ -232,6 +228,8 @@ def export_tensorflow(
232
228
output_names: list of output names from ONNX configuration
233
229
bool: True if the model was exported successfully.
234
230
"""
231
+ from optimum .exporters .onnx .convert import export_tensorflow as export_tensorflow_onnx
232
+
235
233
onnx_path = Path (output ).with_suffix (".onnx" )
236
234
input_names , output_names = export_tensorflow_onnx (model , config , opset , onnx_path )
237
235
ov_model = convert_model (str (onnx_path ))
@@ -252,7 +250,7 @@ def export_tensorflow(
252
250
253
251
def export_pytorch_via_onnx (
254
252
model : Union ["PreTrainedModel" , "ModelMixin" ],
255
- config : OnnxConfig ,
253
+ config : " OnnxConfig" ,
256
254
opset : int ,
257
255
output : Path ,
258
256
device : str = "cpu" ,
@@ -289,6 +287,8 @@ def export_pytorch_via_onnx(
289
287
"""
290
288
import torch
291
289
290
+ from optimum .exporters .onnx .convert import export_pytorch as export_pytorch_to_onnx
291
+
292
292
output = Path (output )
293
293
orig_torch_onnx_export = torch .onnx .export
294
294
torch .onnx .export = functools .partial (orig_torch_onnx_export , do_constant_folding = False )
@@ -317,7 +317,7 @@ def export_pytorch_via_onnx(
317
317
318
318
def export_pytorch (
319
319
model : Union ["PreTrainedModel" , "ModelMixin" ],
320
- config : OnnxConfig ,
320
+ config : " OnnxConfig" ,
321
321
opset : int ,
322
322
output : Path ,
323
323
device : str = "cpu" ,
@@ -359,6 +359,8 @@ def export_pytorch(
359
359
import torch
360
360
from torch .utils ._pytree import tree_map
361
361
362
+ from optimum .exporters .onnx .convert import check_dummy_inputs_are_allowed
363
+
362
364
logger .info (f"Using framework PyTorch: { torch .__version__ } " )
363
365
output = Path (output )
364
366
@@ -878,6 +880,8 @@ def _add_version_info_to_model(model: Model, library_name: Optional[str] = None)
878
880
model .set_rt_info (_nncf_version , ["optimum" , "nncf_version" ])
879
881
input_model = rt_info ["conversion_parameters" ].get ("input_model" , None )
880
882
if input_model is not None and "onnx" in input_model .value :
883
+ import onnx
884
+
881
885
model .set_rt_info (onnx .__version__ , ["optimum" , "onnx_version" ])
882
886
883
887
except Exception :
0 commit comments