|
19 | 19 | from typing import Optional
|
20 | 20 |
|
21 | 21 | from parameterized import parameterized
|
| 22 | +from transformers import AutoConfig |
22 | 23 | from utils_tests import MODEL_NAMES
|
23 | 24 |
|
24 | 25 | from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
|
25 |
| -from optimum.exporters.openvino import export_from_model |
| 26 | +from optimum.exporters.onnx.model_configs import BertOnnxConfig |
| 27 | +from optimum.exporters.openvino import export_from_model, main_export |
26 | 28 | from optimum.exporters.tasks import TasksManager
|
27 | 29 | from optimum.intel import (
|
28 | 30 | OVLatentConsistencyModelPipeline,
|
29 | 31 | OVModelForAudioClassification,
|
30 | 32 | OVModelForCausalLM,
|
| 33 | + OVModelForCustomTasks, |
31 | 34 | OVModelForFeatureExtraction,
|
32 | 35 | OVModelForImageClassification,
|
33 | 36 | OVModelForMaskedLM,
|
@@ -114,3 +117,35 @@ def _openvino_export(
|
114 | 117 | @parameterized.expand(SUPPORTED_ARCHITECTURES)
|
115 | 118 | def test_export(self, model_type: str):
|
116 | 119 | self._openvino_export(model_type)
|
| 120 | + |
| 121 | + |
| 122 | +class CustomExportModelTest(unittest.TestCase): |
| 123 | + def test_export_custom_model(self): |
| 124 | + class BertOnnxConfigWithPooler(BertOnnxConfig): |
| 125 | + @property |
| 126 | + def outputs(self): |
| 127 | + common_outputs = {} |
| 128 | + common_outputs["last_hidden_state"] = {0: "batch_size", 1: "sequence_length"} |
| 129 | + common_outputs["pooler_output"] = {0: "batch_size"} |
| 130 | + return common_outputs |
| 131 | + |
| 132 | + base_task = "feature-extraction" |
| 133 | + custom_task = f"{base_task}-with-pooler" |
| 134 | + model_id = "sentence-transformers/all-MiniLM-L6-v2" |
| 135 | + |
| 136 | + config = AutoConfig.from_pretrained(model_id) |
| 137 | + custom_export_configs = {"model": BertOnnxConfigWithPooler(config, task=base_task)} |
| 138 | + |
| 139 | + with TemporaryDirectory() as tmpdirname: |
| 140 | + main_export( |
| 141 | + model_name_or_path=model_id, |
| 142 | + custom_export_configs=custom_export_configs, |
| 143 | + library_name="transformers", |
| 144 | + output=Path(tmpdirname), |
| 145 | + task=base_task, |
| 146 | + ) |
| 147 | + |
| 148 | + ov_model = OVModelForCustomTasks.from_pretrained(tmpdirname) |
| 149 | + |
| 150 | + self.assertIsInstance(ov_model, OVBaseModel) |
| 151 | + self.assertTrue(ov_model.output_names == {"last_hidden_state": 0, "pooler_output": 1}) |
0 commit comments