Skip to content

Commit e7d8c1e

Browse files
committed
fix
1 parent 46aac95 commit e7d8c1e

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

optimum/exporters/onnx/base.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
)
5050
from ..base import ExportConfig
5151
from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
52-
from .model_patcher import DecoderModelPatcher, Seq2SeqModelPatcher
52+
from .model_patcher import DecoderModelPatcher, ModelPatcher, Seq2SeqModelPatcher
5353

5454

5555
# TODO : moved back onnx imports applied in https://github.com/huggingface/optimum/pull/2114/files after refactorization
@@ -100,12 +100,14 @@
100100
"""
101101

102102

103-
class OnnxConfig(ExportConfig):
103+
class OnnxConfig(ExportConfig, ABC):
104104
DEFAULT_ONNX_OPSET = 11
105105
VARIANTS = {"default": "The default ONNX variant."}
106106
DEFAULT_VARIANT = "default"
107107
# TODO: move PATCHING_SPECS to ExportConfig
108108
PATCHING_SPECS: Optional[List["PatchingSpec"]] = None
109+
_MODEL_PATCHER = ModelPatcher
110+
109111
_TASK_TO_COMMON_OUTPUTS = {
110112
"audio-classification": OrderedDict({"logits": {0: "batch_size"}}),
111113
"audio-frame-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
@@ -425,6 +427,11 @@ def post_process_exported_models(
425427

426428
return models_and_onnx_configs, onnx_files_subpaths
427429

430+
def patch_model_for_export(
431+
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
432+
) -> ModelPatcher:
433+
return self._MODEL_PATCHER(self, model, model_kwargs=model_kwargs)
434+
428435

429436
class OnnxConfigWithPast(OnnxConfig, ABC):
430437
"""

0 commit comments

Comments
 (0)