Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 3ac9f99

Browse files
committedApr 19, 2024·
added a custom export test
1 parent 10e34cd commit 3ac9f99

File tree

1 file changed

+36
-1
lines changed

1 file changed

+36
-1
lines changed
 

‎tests/openvino/test_export.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,18 @@
1919
from typing import Optional
2020

2121
from parameterized import parameterized
22+
from transformers import AutoConfig
2223
from utils_tests import MODEL_NAMES
2324

2425
from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
25-
from optimum.exporters.openvino import export_from_model
26+
from optimum.exporters.onnx.model_configs import BertOnnxConfig
27+
from optimum.exporters.openvino import export_from_model, main_export
2628
from optimum.exporters.tasks import TasksManager
2729
from optimum.intel import (
2830
OVLatentConsistencyModelPipeline,
2931
OVModelForAudioClassification,
3032
OVModelForCausalLM,
33+
OVModelForCustomTasks,
3134
OVModelForFeatureExtraction,
3235
OVModelForImageClassification,
3336
OVModelForMaskedLM,
@@ -114,3 +117,35 @@ def _openvino_export(
114117
@parameterized.expand(SUPPORTED_ARCHITECTURES)
115118
def test_export(self, model_type: str):
116119
self._openvino_export(model_type)
120+
121+
122+
class CustomExportModelTest(unittest.TestCase):
123+
def test_export_custom_model(self):
124+
class BertOnnxConfigWithPooler(BertOnnxConfig):
125+
@property
126+
def outputs(self):
127+
common_outputs = {}
128+
common_outputs["last_hidden_state"] = {0: "batch_size", 1: "sequence_length"}
129+
common_outputs["pooler_output"] = {0: "batch_size"}
130+
return common_outputs
131+
132+
base_task = "feature-extraction"
133+
custom_task = f"{base_task}-with-pooler"
134+
model_id = "sentence-transformers/all-MiniLM-L6-v2"
135+
136+
config = AutoConfig.from_pretrained(model_id)
137+
custom_export_configs = {"model": BertOnnxConfigWithPooler(config, task=base_task)}
138+
139+
with TemporaryDirectory() as tmpdirname:
140+
main_export(
141+
model_name_or_path=model_id,
142+
custom_export_configs=custom_export_configs,
143+
library_name="transformers",
144+
output=Path(tmpdirname),
145+
task=base_task,
146+
)
147+
148+
ov_model = OVModelForCustomTasks.from_pretrained(tmpdirname)
149+
150+
self.assertIsInstance(ov_model, OVBaseModel)
151+
self.assertTrue(ov_model.output_names == {"last_hidden_state": 0, "pooler_output": 1})

0 commit comments

Comments
 (0)
Please sign in to comment.