|
18 | 18 | from tempfile import TemporaryDirectory
|
19 | 19 | from typing import Optional
|
20 | 20 |
|
| 21 | +import torch |
21 | 22 | from parameterized import parameterized
|
22 |
| -from transformers import AutoConfig |
| 23 | +from sentence_transformers import SentenceTransformer, models |
| 24 | +from transformers import AutoConfig, AutoTokenizer |
23 | 25 | from utils_tests import MODEL_NAMES
|
24 | 26 |
|
25 | 27 | from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
|
@@ -124,7 +126,7 @@ def test_export(self, model_type: str):
|
124 | 126 |
|
125 | 127 |
|
126 | 128 | class CustomExportModelTest(unittest.TestCase):
|
127 |
| - def test_export_custom_model(self): |
| 129 | + def test_custom_export_config_model(self): |
128 | 130 | class BertOnnxConfigWithPooler(BertOnnxConfig):
|
129 | 131 | @property
|
130 | 132 | def outputs(self):
|
@@ -157,3 +159,26 @@ def outputs(self):
|
157 | 159 |
|
158 | 160 | self.assertIsInstance(ov_model, OVBaseModel)
|
159 | 161 | self.assertTrue(ov_model.output_names == {"last_hidden_state": 0, "pooler_output": 1})
|
| 162 | + |
| 163 | + def test_export_custom_model(self): |
| 164 | + model_id = "hf-internal-testing/tiny-random-BertModel" |
| 165 | + word_embedding_model = models.Transformer(model_id, max_seq_length=256) |
| 166 | + pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension()) |
| 167 | + dense_model = models.Dense( |
| 168 | + in_features=pooling_model.get_sentence_embedding_dimension(), |
| 169 | + out_features=256, |
| 170 | + ) |
| 171 | + model = SentenceTransformer(modules=[word_embedding_model, pooling_model, dense_model]) |
| 172 | + |
| 173 | + with TemporaryDirectory() as tmpdirname: |
| 174 | + export_from_model(model, output=tmpdirname, task="feature-extraction") |
| 175 | + ov_model = OVModelForCustomTasks.from_pretrained(tmpdirname) |
| 176 | + |
| 177 | + tokenizer = AutoTokenizer.from_pretrained(model_id) |
| 178 | + tokens = tokenizer("This is a sample input", return_tensors="pt") |
| 179 | + with torch.no_grad(): |
| 180 | + model_outputs = model(tokens) |
| 181 | + |
| 182 | + ov_outputs = ov_model(**tokens) |
| 183 | + self.assertTrue(torch.allclose(ov_outputs.token_embeddings, model_outputs.token_embeddings, atol=1e-4)) |
| 184 | + self.assertTrue(torch.allclose(ov_outputs.sentence_embedding, model_outputs.sentence_embedding, atol=1e-4)) |
0 commit comments