Skip to content

Commit 7a3f338

Browse files
committed
move patching specs to onnx config
1 parent d867903 commit 7a3f338

File tree

3 files changed

+4
-5
lines changed

3 files changed

+4
-5
lines changed

optimum/exporters/base.py

-3
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,6 @@
3535
if TYPE_CHECKING:
3636
from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel
3737

38-
from .model_patcher import PatchingSpec
39-
4038
logger = logging.get_logger(__name__)
4139

4240

@@ -112,7 +110,6 @@ class ExportersConfig(ABC):
112110
ATOL_FOR_VALIDATION: Union[float, Dict[str, float]] = 1e-5
113111
MIN_TORCH_VERSION = GLOBAL_MIN_TORCH_VERSION
114112
MIN_TRANSFORMERS_VERSION = GLOBAL_MIN_TRANSFORMERS_VERSION
115-
PATCHING_SPECS: Optional[List["PatchingSpec"]] = None
116113
_TASK_TO_COMMON_OUTPUTS = {
117114
"audio-classification": ["logits"],
118115
"audio-frame-classification": ["logits"],

optimum/exporters/onnx/base.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@
6060
if TYPE_CHECKING:
6161
from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel
6262

63+
from .model_patcher import PatchingSpec
64+
6365
if is_diffusers_available():
6466
from diffusers import ModelMixin
6567

@@ -102,7 +104,8 @@ class OnnxConfig(ExportersConfig):
102104
DEFAULT_ONNX_OPSET = 11
103105
VARIANTS = {"default": "The default ONNX variant."}
104106
DEFAULT_VARIANT = "default"
105-
107+
# TODO: move PATCHING_SPECS to ExportersConfig
108+
PATCHING_SPECS: Optional[List["PatchingSpec"]] = None
106109
_TASK_TO_COMMON_OUTPUTS = {
107110
"audio-classification": OrderedDict({"logits": {0: "batch_size"}}),
108111
"audio-frame-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),

optimum/exporters/onnx/model_configs.py

-1
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@
104104
from transformers import PretrainedConfig
105105
from transformers.modeling_utils import PreTrainedModel
106106

107-
108107
if is_tf_available():
109108
from transformers.modeling_tf_utils import TFPreTrainedModel
110109

0 commit comments

Comments
 (0)