15
15
import logging
16
16
import os
17
17
from pathlib import Path
18
- from typing import Any , Callable , Dict , Optional , Union
18
+ from typing import TYPE_CHECKING , Any , Callable , Dict , List , Optional , Union
19
19
20
20
from requests .exceptions import ConnectionError as RequestsConnectionError
21
21
from transformers import AutoTokenizer
22
22
23
23
from optimum .exporters import TasksManager
24
- from optimum .exporters .onnx import __main__ as optimum_main
25
24
from optimum .exporters .onnx .base import OnnxConfig , OnnxConfigWithPast
25
+ from optimum .exporters .onnx .utils import (
26
+ _get_submodels_for_export_encoder_decoder ,
27
+ _get_submodels_for_export_stable_diffusion ,
28
+ get_encoder_decoder_models_for_export ,
29
+ get_sam_models_for_export ,
30
+ get_stable_diffusion_models_for_export ,
31
+ )
26
32
from optimum .utils import DEFAULT_DUMMY_SHAPES
27
33
from optimum .utils .save_utils import maybe_load_preprocessors , maybe_save_preprocessors
28
34
31
37
from .convert import export_models
32
38
33
39
40
+ if TYPE_CHECKING :
41
+ from transformers import PreTrainedModel , TFPreTrainedModel
42
+
43
+
34
44
# Filename used for the exported OpenVINO IR graph definition
# (the companion .bin weights file is produced alongside it).
OV_XML_FILE_NAME = "openvino_model.xml"

# Size threshold (presumably bytes of model weights) above which
# compression is applied — NOTE(review): usage is outside this view, confirm.
_MAX_UNCOMPRESSED_SIZE = 1e9

# Module-level logger, following the standard per-module logging convention.
logger = logging.getLogger(__name__)
39
49
40
50
51
def _get_submodels_and_export_configs(
    model: Union["PreTrainedModel", "TFPreTrainedModel"],
    task: str,
    custom_onnx_configs: Dict,
    custom_architecture: bool,
    _variant: str,
    int_dtype: str = "int64",
    float_dtype: str = "fp32",
    fn_get_submodels: Optional[Callable] = None,
    preprocessors: Optional[List[Any]] = None,
    no_position_ids: bool = False,
):
    """Resolve the submodels and their export configs for an OpenVINO export.

    Args:
        model: The loaded model to export.
        task: Task name (e.g. ``"text-generation"``, possibly ``"-with-past"``
            suffixed, or a ``"stable-diffusion"`` variant).
        custom_onnx_configs: Per-submodel config overrides, keyed by submodel
            name. For supported architectures, only the overridden submodels
            need to appear; for custom architectures the keys must match the
            submodels exactly.
        custom_architecture: Whether the model type is not natively supported
            by the ``"openvino"`` exporter, requiring custom configs.
        _variant: Export variant to set on the resolved config.
        int_dtype / float_dtype: Dtype identifiers forwarded to the config
            constructors.
        fn_get_submodels: Optional callable returning the submodel dict for a
            custom architecture.
        preprocessors: Optional preprocessors forwarded to the config
            constructor.
        no_position_ids: If True, drop ``position_ids`` from text-generation
            export inputs.

    Returns:
        A tuple ``(onnx_config, models_and_onnx_configs)`` where
        ``models_and_onnx_configs`` maps submodel name to
        ``(submodel, config)``.

    Raises:
        ValueError: For a custom architecture whose submodel names do not
            match the keys of ``custom_onnx_configs``.
    """
    is_stable_diffusion = "stable-diffusion" in task
    if not custom_architecture:
        if is_stable_diffusion:
            # Stable diffusion has no single top-level config; each pipeline
            # component gets its own.
            onnx_config = None
            models_and_onnx_configs = get_stable_diffusion_models_for_export(
                model, int_dtype=int_dtype, float_dtype=float_dtype
            )
        else:
            onnx_config_constructor = TasksManager.get_exporter_config_constructor(
                model=model, exporter="openvino", task=task
            )
            onnx_config_kwargs = {}
            if task.startswith("text-generation") and no_position_ids:
                onnx_config_kwargs["no_position_ids"] = no_position_ids

            onnx_config = onnx_config_constructor(
                model.config,
                int_dtype=int_dtype,
                float_dtype=float_dtype,
                preprocessors=preprocessors,
                **onnx_config_kwargs,
            )

            onnx_config.variant = _variant
            all_variants = "\n".join(
                [f"\t- {name}: {description}" for name, description in onnx_config.VARIANTS.items()]
            )
            logger.info(f"Using the export variant {onnx_config.variant}. Available variants are:\n{all_variants}")

            if model.config.is_encoder_decoder and task.startswith(TasksManager._ENCODER_DECODER_TASKS):
                models_and_onnx_configs = get_encoder_decoder_models_for_export(model, onnx_config)
            elif task.startswith("text-generation"):
                # Decoder-only models need the attention-mask patch, after
                # which the config must be rebuilt for the patched model.
                model = patch_decoder_attention_mask(model)
                onnx_config_constructor = TasksManager.get_exporter_config_constructor(
                    model=model, exporter="openvino", task=task
                )
                onnx_config = onnx_config_constructor(model.config)
                models_and_onnx_configs = {"model": (model, onnx_config)}
            elif model.config.model_type == "sam":
                models_and_onnx_configs = get_sam_models_for_export(model, onnx_config)
            else:
                models_and_onnx_configs = {"model": (model, onnx_config)}

        # When specifying custom ONNX configs for supported transformers
        # architectures, we do not force to specify a custom ONNX config for
        # each submodel — only the overridden ones are replaced.
        for key, custom_onnx_config in custom_onnx_configs.items():
            models_and_onnx_configs[key] = (models_and_onnx_configs[key][0], custom_onnx_config)
    else:
        onnx_config = None
        submodels_for_export = None
        models_and_onnx_configs = {}

        if fn_get_submodels is not None:
            submodels_for_export = fn_get_submodels(model)
        else:
            if is_stable_diffusion:
                submodels_for_export = _get_submodels_for_export_stable_diffusion(model)
            elif model.config.is_encoder_decoder and task.startswith(TasksManager._ENCODER_DECODER_TASKS):
                submodels_for_export = _get_submodels_for_export_encoder_decoder(
                    model, use_past=task.endswith("-with-past")
                )
            elif task.startswith("text-generation"):
                model = patch_decoder_attention_mask(model)
                # BUGFIX: this previously assigned models_and_onnx_configs,
                # leaving submodels_for_export as None and crashing on the
                # .keys() comparison below.
                submodels_for_export = {"model": model}
            else:
                submodels_for_export = {"model": model}

        if submodels_for_export.keys() != custom_onnx_configs.keys():
            logger.error(f"ONNX custom configs for: {', '.join(custom_onnx_configs.keys())}")
            logger.error(f"Submodels to export: {', '.join(submodels_for_export.keys())}")
            raise ValueError(
                "Trying to export a custom model, but could not find as many custom ONNX configs as the number of submodels to export. Please specify the fn_get_submodels argument, that should return a dictionary of submodules with as many items as the provided custom_onnx_configs dictionary."
            )

        for key, custom_onnx_config in custom_onnx_configs.items():
            models_and_onnx_configs[key] = (submodels_for_export[key], custom_onnx_config)

    # Default to the first ONNX config for stable-diffusion and custom
    # architecture case.
    if onnx_config is None:
        onnx_config = next(iter(models_and_onnx_configs.values()))[1]

    return onnx_config, models_and_onnx_configs
145
+
146
+
41
147
def main_export (
42
148
model_name_or_path : str ,
43
149
output : Union [str , Path ],
@@ -183,7 +289,7 @@ def main_export(
183
289
f"If you want to support { model_type } please propose a PR or open up an issue."
184
290
)
185
291
if model .config .model_type .replace ("-" , "_" ) not in TasksManager .get_supported_model_type_for_task (
186
- task , exporter = "onnx "
292
+ task , exporter = "openvino "
187
293
):
188
294
custom_architecture = True
189
295
@@ -200,7 +306,7 @@ def main_export(
200
306
if (
201
307
not custom_architecture
202
308
and not is_stable_diffusion
203
- and task + "-with-past" in TasksManager .get_supported_tasks_for_model_type (model_type , "onnx " )
309
+ and task + "-with-past" in TasksManager .get_supported_tasks_for_model_type (model_type , "openvino " )
204
310
):
205
311
if original_task == "auto" : # Make -with-past the default if --task was not explicitely specified
206
312
task = task + "-with-past"
@@ -222,24 +328,15 @@ def main_export(
222
328
preprocessors = maybe_load_preprocessors (
223
329
model_name_or_path , subfolder = subfolder , trust_remote_code = trust_remote_code
224
330
)
225
- if not task .startswith ("text-generation" ):
226
- onnx_config , models_and_onnx_configs = optimum_main ._get_submodels_and_onnx_configs (
227
- model = model ,
228
- task = task ,
229
- monolith = False ,
230
- custom_onnx_configs = custom_onnx_configs if custom_onnx_configs is not None else {},
231
- custom_architecture = custom_architecture ,
232
- fn_get_submodels = fn_get_submodels ,
233
- preprocessors = preprocessors ,
234
- _variant = "default" ,
235
- )
236
- else :
237
- # TODO : ModelPatcher will be added in next optimum release
238
- model = patch_decoder_attention_mask (model )
239
-
240
- onnx_config_constructor = TasksManager .get_exporter_config_constructor (model = model , exporter = "onnx" , task = task )
241
- onnx_config = onnx_config_constructor (model .config )
242
- models_and_onnx_configs = {"model" : (model , onnx_config )}
331
+ onnx_config , models_and_onnx_configs = _get_submodels_and_export_configs (
332
+ model = model ,
333
+ task = task ,
334
+ custom_onnx_configs = custom_onnx_configs if custom_onnx_configs is not None else {},
335
+ custom_architecture = custom_architecture ,
336
+ fn_get_submodels = fn_get_submodels ,
337
+ preprocessors = preprocessors ,
338
+ _variant = "default" ,
339
+ )
243
340
244
341
if int8 is None :
245
342
int8 = False
@@ -276,7 +373,7 @@ def main_export(
276
373
generation_config = getattr (model , "generation_config" , None )
277
374
if generation_config is not None :
278
375
generation_config .save_pretrained (output )
279
- maybe_save_preprocessors (model_name_or_path , output )
376
+ maybe_save_preprocessors (model_name_or_path , output , trust_remote_code = trust_remote_code )
280
377
281
378
if model .config .is_encoder_decoder and task .startswith ("text-generation" ):
282
379
raise ValueError (
0 commit comments