Skip to content

Commit 46f1c26

Browse files
committed
fix style
1 parent 58b906f commit 46f1c26

File tree

4 files changed

+12
-45
lines changed

4 files changed

+12
-45
lines changed

optimum/exporters/base.py

+8-31
Original file line numberDiff line numberDiff line change
@@ -14,28 +14,15 @@
1414
# limitations under the License.
1515
"""Base exporters config."""
1616

17-
from abc import ABC
18-
19-
20-
2117
import copy
22-
import enum
23-
import gc
24-
import inspect
25-
import itertools
26-
import os
27-
import re
2818
from abc import ABC, abstractmethod
29-
from collections import OrderedDict
30-
from pathlib import Path
31-
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
19+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
3220

33-
import numpy as np
34-
from transformers.utils import is_accelerate_available, is_torch_available
21+
from transformers.utils import is_torch_available
3522

3623

3724
if is_torch_available():
38-
import torch.nn as nn
25+
pass
3926

4027
from .utils import (
4128
DEFAULT_DUMMY_SHAPES,
@@ -46,19 +33,18 @@
4633
from .utils import TRANSFORMERS_MINIMUM_VERSION as GLOBAL_MIN_TRANSFORMERS_VERSION
4734
from .utils.doc import add_dynamic_docstring
4835
from .utils.import_utils import is_torch_version, is_transformers_version
49-
from .error_utils import MissingMandatoryAxisDimension
36+
5037

5138
# from .model_patcher import ModelPatcher
5239

5340
if TYPE_CHECKING:
54-
from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel
41+
from transformers import PretrainedConfig
5542

5643
from .model_patcher import PatchingSpec
5744

5845
logger = logging.get_logger(__name__)
5946

6047

61-
6248
GENERATE_DUMMY_DOCSTRING = r"""
6349
Generates the dummy inputs necessary for tracing the model. If not explicitely specified, default input shapes are used.
6450
@@ -90,13 +76,11 @@
9076
"""
9177

9278

93-
9479
# TODO: Remove
9580
class ExportConfig(ABC):
9681
pass
9782

9883

99-
10084
class ExportersConfig(ABC):
10185
"""
10286
Base class describing metadata on how to export the model through the ONNX format.
@@ -141,19 +125,19 @@ class ExportersConfig(ABC):
141125
"audio-xvector": ["logits"], # for onnx : ["logits", "embeddings"]
142126
"depth-estimation": ["predicted_depth"],
143127
"document-question-answering": ["logits"],
144-
"feature-extraction": ["last_hidden_state"], # for neuron : ["last_hidden_state", "pooler_output"]
128+
"feature-extraction": ["last_hidden_state"], # for neuron : ["last_hidden_state", "pooler_output"]
145129
"fill-mask": ["logits"],
146130
"image-classification": ["logits"],
147131
"image-segmentation": ["logits"], # for tflite : ["logits", "pred_boxes", "pred_masks"]
148132
"image-to-text": ["logits"],
149133
"image-to-image": ["reconstruction"],
150134
"mask-generation": ["logits"],
151-
"masked-im": ["logits"], # for onnx : ["reconstruction"]
135+
"masked-im": ["logits"], # for onnx : ["reconstruction"]
152136
"multiple-choice": ["logits"],
153137
"object-detection": ["logits", "pred_boxes"],
154138
"question-answering": ["start_logits", "end_logits"],
155139
"semantic-segmentation": ["logits"],
156-
"text2text-generation": ["logits"], # for tflite : ["logits", "encoder_last_hidden_state"],
140+
"text2text-generation": ["logits"], # for tflite : ["logits", "encoder_last_hidden_state"],
157141
"text-classification": ["logits"],
158142
"text-generation": ["logits"],
159143
"time-series-forecasting": ["prediction_outputs"],
@@ -179,7 +163,6 @@ def __init__(
179163
self.mandatory_axes = ()
180164
self._axes: Dict[str, int] = {}
181165

182-
183166
def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGenerator]:
184167
"""
185168
Instantiates the dummy input generators from `self.DUMMY_INPUT_GENERATOR_CLASSES`.
@@ -190,7 +173,6 @@ def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGene
190173
# self._validate_mandatory_axes()
191174
return [cls_(self.task, self._normalized_config, **kwargs) for cls_ in self.DUMMY_INPUT_GENERATOR_CLASSES]
192175

193-
194176
@property
195177
@abstractmethod
196178
def inputs(self) -> Dict[str, Dict[int, str]]:
@@ -213,7 +195,6 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
213195
common_outputs = self._TASK_TO_COMMON_OUTPUTS[self.task]
214196
return copy.deepcopy(common_outputs)
215197

216-
217198
@property
218199
def values_override(self) -> Optional[Dict[str, Any]]:
219200
"""
@@ -251,18 +232,15 @@ def is_torch_support_available(self) -> bool:
251232

252233
return False
253234

254-
255235
@add_dynamic_docstring(text=GENERATE_DUMMY_DOCSTRING, dynamic_elements=DEFAULT_DUMMY_SHAPES)
256236
def generate_dummy_inputs(self, framework: str = "pt", **kwargs) -> Dict:
257-
258237
"""
259238
Generates dummy inputs that the exported model should be able to process.
260239
This method is actually used to determine the input specs that are needed for the export.
261240
262241
Returns:
263242
`Dict[str, [tf.Tensor, torch.Tensor]]`: A dictionary mapping input names to dummy tensors.
264243
"""
265-
266244

267245
dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs)
268246
dummy_inputs = {}
@@ -303,5 +281,4 @@ def flatten_inputs(cls, inputs: Dict[str, Any]) -> Dict[str, Any]:
303281
# ) -> ModelPatcher:
304282
# return ModelPatcher(self, model, model_kwargs=model_kwargs)
305283

306-
307284
############################################################################################################################################################

optimum/exporters/onnx/base.py

+2-10
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import itertools
2222
import os
2323
import re
24-
from abc import ABC, abstractmethod
24+
from abc import ABC
2525
from collections import OrderedDict
2626
from pathlib import Path
2727
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
@@ -41,16 +41,13 @@
4141
is_diffusers_available,
4242
logging,
4343
)
44-
from ...utils import TORCH_MINIMUM_VERSION as GLOBAL_MIN_TORCH_VERSION
45-
from ...utils import TRANSFORMERS_MINIMUM_VERSION as GLOBAL_MIN_TRANSFORMERS_VERSION
4644
from ...utils.doc import add_dynamic_docstring
4745
from ...utils.import_utils import (
4846
is_onnx_available,
4947
is_onnxruntime_available,
50-
is_torch_version,
5148
is_transformers_version,
5249
)
53-
from ..base import ExportConfig, ExportersConfig
50+
from ..base import ExportersConfig
5451
from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
5552
from .model_patcher import ModelPatcher, Seq2SeqModelPatcher
5653

@@ -66,7 +63,6 @@
6663
if is_diffusers_available():
6764
from diffusers import ModelMixin
6865

69-
from .model_patcher import PatchingSpec
7066

7167
logger = logging.get_logger(__name__)
7268

@@ -103,7 +99,6 @@
10399

104100

105101
class OnnxConfig(ExportersConfig):
106-
107102
DEFAULT_ONNX_OPSET = 11
108103
VARIANTS = {"default": "The default ONNX variant."}
109104
DEFAULT_VARIANT = "default"
@@ -281,7 +276,6 @@ def patch_model_for_export(
281276
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
282277
) -> ModelPatcher:
283278
return ModelPatcher(self, model, model_kwargs=model_kwargs)
284-
285279

286280
@property
287281
def torch_to_onnx_input_map(self) -> Dict[str, str]:
@@ -348,8 +342,6 @@ def ordered_inputs(self, model: Union["PreTrainedModel", "TFPreTrainedModel"]) -
348342
ordered_inputs[name] = dynamic_axes
349343
return ordered_inputs
350344

351-
352-
353345
# TODO: use instead flatten_inputs and remove
354346
@classmethod
355347
def flatten_output_collection_property(cls, name: str, field: Iterable[Any]) -> Dict[str, Any]:

optimum/exporters/tflite/base.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
"""TensorFlow Lite configuration base classes."""
1616

17-
from abc import ABC, abstractmethod
17+
from abc import ABC
1818
from ctypes import ArgumentError
1919
from dataclasses import dataclass
2020
from enum import Enum
@@ -192,7 +192,6 @@ def __init__(
192192
point_batch_size: Optional[int] = None,
193193
nb_points_per_image: Optional[int] = None,
194194
):
195-
196195
super().__init__(config=config, task=task, int_dtype="int64", float_dtype="fp32")
197196

198197
# self.mandatory_axes = ()
@@ -269,7 +268,6 @@ def _create_dummy_input_generator_classes(self) -> List["DummyInputGenerator"]:
269268
def generate_dummy_inputs(self) -> Dict[str, "tf.Tensor"]:
270269
return super().generate_dummy_inputs(framework="tf")
271270

272-
273271
@property
274272
def inputs_specs(self) -> List["TensorSpec"]:
275273
"""

optimum/exporters/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,7 @@ def get_speecht5_models_for_export(
545545
use_past=use_past,
546546
use_past_in_inputs=False, # Irrelevant here.
547547
behavior=config._behavior, # Irrelevant here.
548-
preprocessors=config._preprocessors,
548+
# preprocessors=config._preprocessors,
549549
is_postnet_and_vocoder=True,
550550
legacy=config.legacy,
551551
)

0 commit comments

Comments
 (0)