# limitations under the License.
"""Base exporters config."""

- from abc import ABC
-
-
-
import copy
- import enum
- import gc
- import inspect
- import itertools
- import os
- import re
from abc import ABC, abstractmethod
- from collections import OrderedDict
- from pathlib import Path
- from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

- import numpy as np
- from transformers.utils import is_accelerate_available, is_torch_available
+ from transformers.utils import is_torch_available


if is_torch_available():
-     import torch.nn as nn
+     pass

from .utils import (
    DEFAULT_DUMMY_SHAPES,
from .utils import TRANSFORMERS_MINIMUM_VERSION as GLOBAL_MIN_TRANSFORMERS_VERSION
from .utils.doc import add_dynamic_docstring
from .utils.import_utils import is_torch_version, is_transformers_version
- from .error_utils import MissingMandatoryAxisDimension
+

# from .model_patcher import ModelPatcher

if TYPE_CHECKING:
-     from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel
+     from transformers import PretrainedConfig

    from .model_patcher import PatchingSpec

logger = logging.get_logger(__name__)


-
GENERATE_DUMMY_DOCSTRING = r"""
    Generates the dummy inputs necessary for tracing the model. If not explicitly specified, default input shapes are used.

    """


-
# TODO: Remove
class ExportConfig(ABC):
    pass


-
class ExportersConfig(ABC):
    """
    Base class describing metadata on how to export the model through the ONNX format.
@@ -141,19 +125,19 @@ class ExportersConfig(ABC):
"audio-xvector" : ["logits" ], # for onnx : ["logits", "embeddings"]
142
126
"depth-estimation" : ["predicted_depth" ],
143
127
"document-question-answering" : ["logits" ],
144
- "feature-extraction" : ["last_hidden_state" ], # for neuron : ["last_hidden_state", "pooler_output"]
128
+ "feature-extraction" : ["last_hidden_state" ], # for neuron : ["last_hidden_state", "pooler_output"]
145
129
"fill-mask" : ["logits" ],
146
130
"image-classification" : ["logits" ],
147
131
"image-segmentation" : ["logits" ], # for tflite : ["logits", "pred_boxes", "pred_masks"]
148
132
"image-to-text" : ["logits" ],
149
133
"image-to-image" : ["reconstruction" ],
150
134
"mask-generation" : ["logits" ],
151
- "masked-im" : ["logits" ], # for onnx : ["reconstruction"]
135
+ "masked-im" : ["logits" ], # for onnx : ["reconstruction"]
152
136
"multiple-choice" : ["logits" ],
153
137
"object-detection" : ["logits" , "pred_boxes" ],
154
138
"question-answering" : ["start_logits" , "end_logits" ],
155
139
"semantic-segmentation" : ["logits" ],
156
- "text2text-generation" : ["logits" ], # for tflite : ["logits", "encoder_last_hidden_state"],
140
+ "text2text-generation" : ["logits" ], # for tflite : ["logits", "encoder_last_hidden_state"],
157
141
"text-classification" : ["logits" ],
158
142
"text-generation" : ["logits" ],
159
143
"time-series-forecasting" : ["prediction_outputs" ],
@@ -179,7 +163,6 @@ def __init__(
        self.mandatory_axes = ()
        self._axes: Dict[str, int] = {}

-
    def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGenerator]:
        """
        Instantiates the dummy input generators from `self.DUMMY_INPUT_GENERATOR_CLASSES`.
@@ -190,7 +173,6 @@ def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGene
        # self._validate_mandatory_axes()
        return [cls_(self.task, self._normalized_config, **kwargs) for cls_ in self.DUMMY_INPUT_GENERATOR_CLASSES]

-
    @property
    @abstractmethod
    def inputs(self) -> Dict[str, Dict[int, str]]:
@@ -213,7 +195,6 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
        common_outputs = self._TASK_TO_COMMON_OUTPUTS[self.task]
        return copy.deepcopy(common_outputs)

-
    @property
    def values_override(self) -> Optional[Dict[str, Any]]:
        """
@@ -251,18 +232,15 @@ def is_torch_support_available(self) -> bool:

        return False

-
    @add_dynamic_docstring(text=GENERATE_DUMMY_DOCSTRING, dynamic_elements=DEFAULT_DUMMY_SHAPES)
    def generate_dummy_inputs(self, framework: str = "pt", **kwargs) -> Dict:
-
        """
        Generates dummy inputs that the exported model should be able to process.
        This method is actually used to determine the input specs that are needed for the export.

        Returns:
            `Dict[str, [tf.Tensor, torch.Tensor]]`: A dictionary mapping input names to dummy tensors.
        """
-

        dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs)
        dummy_inputs = {}
@@ -303,5 +281,4 @@ def flatten_inputs(cls, inputs: Dict[str, Any]) -> Dict[str, Any]:
    # ) -> ModelPatcher:
    #     return ModelPatcher(self, model, model_kwargs=model_kwargs)

-
############################################################################################################################################################
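
For orientation, the sketch below shows how a concrete subclass of the `ExportersConfig` base class introduced in this diff might wire the pieces together: the `DUMMY_INPUT_GENERATOR_CLASSES` tuple, the abstract `inputs` property, and the `generate_dummy_inputs` helper. This is a minimal, hypothetical illustration, not code from this commit; the import paths, the `NORMALIZED_CONFIG_CLASS` attribute, and the constructor signature used in the commented usage lines are assumptions.

from typing import Dict

from optimum.exporters.base import ExportersConfig  # assumed import path for the class defined in this diff
from optimum.utils import DummyTextInputGenerator, NormalizedTextConfig


class ExampleTextExportersConfig(ExportersConfig):
    # Generator classes consumed by `_create_dummy_input_generator_classes`, which the diff shows
    # instantiating each entry as `cls_(self.task, self._normalized_config, **kwargs)`.
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator,)

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        # The abstract `inputs` property maps each input name to its dynamic axes.
        return {
            "input_ids": {0: "batch_size", 1: "sequence_length"},
            "attention_mask": {0: "batch_size", 1: "sequence_length"},
        }


# Assuming a constructor of the form `ExportersConfig(config, task=...)`, which this diff does not show:
# export_config = ExampleTextExportersConfig(model.config, task="text-classification")
# dummy_inputs = export_config.generate_dummy_inputs(framework="pt", batch_size=2, sequence_length=16)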