Skip to content

Commit 813d7c0

Browse files
authored
Add custom model export test (#677)
* Add custom model export test * format
1 parent 683133f commit 813d7c0

File tree

2 files changed

+28
-3
lines changed

2 files changed

+28
-3
lines changed

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
"transformers_stream_generator",
5454
"einops",
5555
"tiktoken",
56-
"sentence_transformers",
56+
"sentence-transformers",
5757
]
5858

5959
QUALITY_REQUIRE = ["black~=23.1", "ruff>=0.0.241"]

tests/openvino/test_export.py

+27-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818
from tempfile import TemporaryDirectory
1919
from typing import Optional
2020

21+
import torch
2122
from parameterized import parameterized
22-
from transformers import AutoConfig
23+
from sentence_transformers import SentenceTransformer, models
24+
from transformers import AutoConfig, AutoTokenizer
2325
from utils_tests import MODEL_NAMES
2426

2527
from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
@@ -124,7 +126,7 @@ def test_export(self, model_type: str):
124126

125127

126128
class CustomExportModelTest(unittest.TestCase):
127-
def test_export_custom_model(self):
129+
def test_custom_export_config_model(self):
128130
class BertOnnxConfigWithPooler(BertOnnxConfig):
129131
@property
130132
def outputs(self):
@@ -157,3 +159,26 @@ def outputs(self):
157159

158160
self.assertIsInstance(ov_model, OVBaseModel)
159161
self.assertTrue(ov_model.output_names == {"last_hidden_state": 0, "pooler_output": 1})
162+
163+
def test_export_custom_model(self):
164+
model_id = "hf-internal-testing/tiny-random-BertModel"
165+
word_embedding_model = models.Transformer(model_id, max_seq_length=256)
166+
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
167+
dense_model = models.Dense(
168+
in_features=pooling_model.get_sentence_embedding_dimension(),
169+
out_features=256,
170+
)
171+
model = SentenceTransformer(modules=[word_embedding_model, pooling_model, dense_model])
172+
173+
with TemporaryDirectory() as tmpdirname:
174+
export_from_model(model, output=tmpdirname, task="feature-extraction")
175+
ov_model = OVModelForCustomTasks.from_pretrained(tmpdirname)
176+
177+
tokenizer = AutoTokenizer.from_pretrained(model_id)
178+
tokens = tokenizer("This is a sample input", return_tensors="pt")
179+
with torch.no_grad():
180+
model_outputs = model(tokens)
181+
182+
ov_outputs = ov_model(**tokens)
183+
self.assertTrue(torch.allclose(ov_outputs.token_embeddings, model_outputs.token_embeddings, atol=1e-4))
184+
self.assertTrue(torch.allclose(ov_outputs.sentence_embedding, model_outputs.sentence_embedding, atol=1e-4))

0 commit comments

Comments
 (0)