|
49 | 49 | )
|
50 | 50 | from ..base import ExportConfig
|
51 | 51 | from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
|
52 |
| -from .model_patcher import DecoderModelPatcher, Seq2SeqModelPatcher |
| 52 | +from .model_patcher import DecoderModelPatcher, ModelPatcher, Seq2SeqModelPatcher |
53 | 53 |
|
54 | 54 |
|
55 | 55 | # TODO : moved back onnx imports applied in https://github.com/huggingface/optimum/pull/2114/files after refactorization
|
|
100 | 100 | """
|
101 | 101 |
|
102 | 102 |
|
103 |
| -class OnnxConfig(ExportConfig): |
| 103 | +class OnnxConfig(ExportConfig, ABC): |
104 | 104 | DEFAULT_ONNX_OPSET = 11
|
105 | 105 | VARIANTS = {"default": "The default ONNX variant."}
|
106 | 106 | DEFAULT_VARIANT = "default"
|
107 | 107 | # TODO: move PATCHING_SPECS to ExportConfig
|
108 | 108 | PATCHING_SPECS: Optional[List["PatchingSpec"]] = None
|
| 109 | + _MODEL_PATCHER = ModelPatcher |
| 110 | + |
109 | 111 | _TASK_TO_COMMON_OUTPUTS = {
|
110 | 112 | "audio-classification": OrderedDict({"logits": {0: "batch_size"}}),
|
111 | 113 | "audio-frame-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
|
@@ -425,6 +427,11 @@ def post_process_exported_models(
|
425 | 427 |
|
426 | 428 | return models_and_onnx_configs, onnx_files_subpaths
|
427 | 429 |
|
| 430 | + def patch_model_for_export( |
| 431 | + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None |
| 432 | + ) -> ModelPatcher: |
| 433 | + return self._MODEL_PATCHER(self, model, model_kwargs=model_kwargs) |
| 434 | + |
428 | 435 |
|
429 | 436 | class OnnxConfigWithPast(OnnxConfig, ABC):
|
430 | 437 | """
|
|
0 commit comments