
Commit f392c9b

Custom tasks modeling (#669)

* added custom tasks modeling
* patched output names for now and added a ViT with attentions test
* test passing
* fix attentions
* added hidden states test
* remove unnecessary names processing
* better testing
* added inputs check
* added a bert with pooler
* fix name
* added a custom export test
* better custom config

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>

1 parent ff1d94b commit f392c9b

File tree (6 files changed: +195 -2)

optimum/intel/__init__.py
optimum/intel/openvino/__init__.py
optimum/intel/openvino/modeling.py
tests/openvino/test_export.py
tests/openvino/test_modeling.py
tests/openvino/utils_tests.py

optimum/intel/__init__.py (+3 -1)

@@ -116,8 +116,9 @@
         "OVModelForAudioClassification",
         "OVModelForAudioFrameClassification",
         "OVModelForAudioXVector",
-        "OVModelForCTC",
         "OVModelForCausalLM",
+        "OVModelForCTC",
+        "OVModelForCustomTasks",
         "OVModelForFeatureExtraction",
         "OVModelForImageClassification",
         "OVModelForMaskedLM",
@@ -242,6 +243,7 @@
         OVModelForAudioXVector,
         OVModelForCausalLM,
         OVModelForCTC,
+        OVModelForCustomTasks,
         OVModelForFeatureExtraction,
         OVModelForImageClassification,
         OVModelForMaskedLM,

optimum/intel/openvino/__init__.py (+1)

@@ -49,6 +49,7 @@
     OVModelForAudioFrameClassification,
     OVModelForAudioXVector,
     OVModelForCTC,
+    OVModelForCustomTasks,
     OVModelForFeatureExtraction,
     OVModelForImageClassification,
     OVModelForMaskedLM,

optimum/intel/openvino/modeling.py (+64)

@@ -43,6 +43,7 @@
     CausalLMOutput,
     ImageClassifierOutput,
     MaskedLMOutput,
+    ModelOutput,
     QuestionAnsweringModelOutput,
     SequenceClassifierOutput,
     TokenClassifierOutput,
@@ -953,3 +954,66 @@ def forward(
         logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"]
 
         return TokenClassifierOutput(logits=logits)
+
+
+CUSTOM_TASKS_EXAMPLE = """
+    Example of custom tasks (e.g. a sentence-transformers model with a pooler head):
+
+    ```python
+    >>> from transformers import {processor_class}
+    >>> from optimum.intel import {model_class}
+
+    >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
+    >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+    >>> inputs = tokenizer("I love burritos!", return_tensors="np")
+
+    >>> outputs = model(**inputs)
+    >>> last_hidden_state = outputs.last_hidden_state
+    >>> pooler_output = outputs.pooler_output
+    ```
+"""
+
+
+@add_start_docstrings(
+    """
+    OpenVINO Model for custom tasks. It can be used to leverage the inference acceleration for any single-file OpenVINO model that may use custom inputs and outputs.
+    """,
+    MODEL_START_DOCSTRING,
+)
+class OVModelForCustomTasks(OVModel):
+    @add_start_docstrings_to_model_forward(
+        CUSTOM_TASKS_EXAMPLE.format(
+            processor_class=_TOKENIZER_FOR_DOC,
+            model_class="OVModelForCustomTasks",
+            checkpoint="IlyasMoutawwakil/sbert-all-MiniLM-L6-v2-with-pooler",
+        )
+    )
+    def forward(self, **kwargs):
+        expected_inputs_names = set(self.input_names)
+        inputs_names = set(kwargs)
+
+        if not expected_inputs_names.issubset(inputs_names):
+            raise ValueError(
+                f"Got unexpected inputs: expecting the following inputs: {', '.join(expected_inputs_names)} but got: {', '.join(inputs_names)}."
+            )
+
+        np_inputs = isinstance(next(iter(kwargs.values())), np.ndarray)
+        inputs = {}
+        for input_name in self.input_names:
+            inputs[input_name] = np.array(kwargs.pop(input_name)) if not np_inputs else kwargs.pop(input_name)
+
+        outputs = self.request(inputs)
+
+        model_outputs = {}
+        for key, value in outputs.items():
+            key_name = next(iter(key.names))
+            if "." in key_name:
+                key_name = key_name.split(".")[0]
+                if key_name not in model_outputs:
+                    model_outputs[key_name] = []
+                model_outputs[key_name].append(torch.from_numpy(value).to(self.device) if not np_inputs else value)
+            else:
+                model_outputs[key_name] = torch.from_numpy(value).to(self.device) if not np_inputs else value
+
+        return ModelOutput(**model_outputs)
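The `forward` method above accepts whatever input names the compiled model declares and returns a generic `ModelOutput`. Output tensors whose names contain a dot, e.g. `attentions.0`, `attentions.1`, are grouped into a list under their prefix, so they surface as a single `attentions` field. A minimal usage sketch, assuming the `IlyasMoutawwakil/vit-with-attentions` checkpoint exercised in the tests further down:

```python
import requests
from PIL import Image
from transformers import AutoFeatureExtractor

from optimum.intel import OVModelForCustomTasks

# Checkpoint used in the tests below: a ViT classifier whose OpenVINO graph
# also exposes numbered "attentions.*" outputs.
model_id = "IlyasMoutawwakil/vit-with-attentions"
preprocessor = AutoFeatureExtractor.from_pretrained(model_id)
model = OVModelForCustomTasks.from_pretrained(model_id)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

outputs = model(**preprocessor(images=image, return_tensors="pt"))
print(outputs.logits.shape)     # classification logits
print(len(outputs.attentions))  # one attention tensor per transformer layer
```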

tests/openvino/test_export.py (+40 -1)

@@ -19,15 +19,18 @@
 from typing import Optional
 
 from parameterized import parameterized
+from transformers import AutoConfig
 from utils_tests import MODEL_NAMES
 
 from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
-from optimum.exporters.openvino import export_from_model
+from optimum.exporters.onnx.model_configs import BertOnnxConfig
+from optimum.exporters.openvino import export_from_model, main_export
 from optimum.exporters.tasks import TasksManager
 from optimum.intel import (
     OVLatentConsistencyModelPipeline,
     OVModelForAudioClassification,
     OVModelForCausalLM,
+    OVModelForCustomTasks,
     OVModelForFeatureExtraction,
     OVModelForImageClassification,
     OVModelForMaskedLM,
@@ -114,3 +117,39 @@ def _openvino_export(
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_export(self, model_type: str):
         self._openvino_export(model_type)
+
+
+class CustomExportModelTest(unittest.TestCase):
+    def test_export_custom_model(self):
+        class BertOnnxConfigWithPooler(BertOnnxConfig):
+            @property
+            def outputs(self):
+                if self.task == "feature-extraction-with-pooler":
+                    common_outputs = {}
+                    common_outputs["last_hidden_state"] = {0: "batch_size", 1: "sequence_length"}
+                    common_outputs["pooler_output"] = {0: "batch_size"}
+                else:
+                    common_outputs = super().outputs
+
+                return common_outputs
+
+        base_task = "feature-extraction"
+        custom_task = f"{base_task}-with-pooler"
+        model_id = "sentence-transformers/all-MiniLM-L6-v2"
+
+        config = AutoConfig.from_pretrained(model_id)
+        custom_export_configs = {"model": BertOnnxConfigWithPooler(config, task=custom_task)}
+
+        with TemporaryDirectory() as tmpdirname:
+            main_export(
+                model_name_or_path=model_id,
+                custom_export_configs=custom_export_configs,
+                library_name="transformers",
+                output=Path(tmpdirname),
+                task=base_task,
+            )
+
+            ov_model = OVModelForCustomTasks.from_pretrained(tmpdirname)
+
+            self.assertIsInstance(ov_model, OVBaseModel)
+            self.assertTrue(ov_model.output_names == {"last_hidden_state": 0, "pooler_output": 1})
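The test stops at checking `output_names`, but a model exported this way is directly usable for inference. A minimal sketch, assuming `ov_model_dir` points at a directory produced by the `main_export` call above (the path is hypothetical):

```python
from transformers import AutoTokenizer

from optimum.intel import OVModelForCustomTasks

# Assumed path: a directory produced by main_export with the
# BertOnnxConfigWithPooler config from the test above.
ov_model_dir = "path/to/exported/model"

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = OVModelForCustomTasks.from_pretrained(ov_model_dir)

inputs = tokenizer("I love burritos!", return_tensors="np")
outputs = model(**inputs)

print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)
print(outputs.pooler_output.shape)      # (batch_size, hidden_size)
```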

tests/openvino/test_modeling.py (+85)

@@ -63,6 +63,7 @@
     OVModelForAudioXVector,
     OVModelForCausalLM,
     OVModelForCTC,
+    OVModelForCustomTasks,
     OVModelForFeatureExtraction,
     OVModelForImageClassification,
     OVModelForMaskedLM,
@@ -1525,3 +1526,87 @@ def test_pipeline_image_to_text(self, model_arch: str):
         self.assertIsInstance(outputs[0]["generated_text"], str)
 
         gc.collect()
+
+
+class OVModelForCustomTasksIntegrationTest(unittest.TestCase):
+    SUPPORTED_ARCHITECTURES_WITH_ATTENTION = ["vit-with-attentions"]
+    SUPPORTED_ARCHITECTURES_WITH_HIDDEN_STATES = ["vit-with-hidden-states"]
+
+    def _get_sample_image(self):
+        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        image = Image.open(requests.get(url, stream=True).raw)
+        return image
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_ATTENTION)
+    def test_compare_output_attentions(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+
+        image = self._get_sample_image()
+        preprocessor = AutoFeatureExtractor.from_pretrained(model_id)
+        inputs = preprocessor(images=image, return_tensors="pt")
+
+        transformers_model = AutoModelForImageClassification.from_pretrained(model_id)
+        transformers_model.eval()
+        with torch.no_grad():
+            transformers_outputs = transformers_model(**inputs, output_attentions=True)
+
+        ov_model = OVModelForCustomTasks.from_pretrained(model_id, ov_config=F32_CONFIG)
+        self.assertIsInstance(ov_model.config, PretrainedConfig)
+
+        for input_type in ["pt", "np"]:
+            inputs = preprocessor(images=image, return_tensors=input_type)
+            ov_outputs = ov_model(**inputs)
+            self.assertIn("logits", ov_outputs)
+            self.assertIsInstance(ov_outputs.logits, TENSOR_ALIAS_TO_TYPE[input_type])
+            self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-4))
+            self.assertTrue(len(ov_outputs.attentions) == len(transformers_outputs.attentions))
+            for i in range(len(ov_outputs.attentions)):
+                self.assertTrue(
+                    torch.allclose(
+                        torch.Tensor(ov_outputs.attentions[i]),
+                        transformers_outputs.attentions[i],
+                        atol=1e-4,  # attentions are accurate
+                        rtol=1e-4,  # attentions are accurate
+                    ),
+                    f"Attention mismatch at layer {i}",
+                )
+
+        del transformers_model
+        del ov_model
+        gc.collect()
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_HIDDEN_STATES)
+    def test_compare_output_hidden_states(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+
+        image = self._get_sample_image()
+        preprocessor = AutoFeatureExtractor.from_pretrained(model_id)
+        inputs = preprocessor(images=image, return_tensors="pt")
+
+        transformers_model = AutoModelForImageClassification.from_pretrained(model_id)
+        transformers_model.eval()
+        with torch.no_grad():
+            transformers_outputs = transformers_model(**inputs, output_hidden_states=True)
+
+        ov_model = OVModelForCustomTasks.from_pretrained(model_id, ov_config=F32_CONFIG)
+        self.assertIsInstance(ov_model.config, PretrainedConfig)
+        for input_type in ["pt", "np"]:
+            inputs = preprocessor(images=image, return_tensors=input_type)
+            ov_outputs = ov_model(**inputs)
+            self.assertIn("logits", ov_outputs)
+            self.assertIsInstance(ov_outputs.logits, TENSOR_ALIAS_TO_TYPE[input_type])
+            self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-4))
+            self.assertTrue(len(ov_outputs.hidden_states) == len(transformers_outputs.hidden_states))
+            for i in range(len(ov_outputs.hidden_states)):
+                self.assertTrue(
+                    torch.allclose(
+                        torch.Tensor(ov_outputs.hidden_states[i]),
+                        transformers_outputs.hidden_states[i],
+                        atol=1e-3,  # hidden states are less accurate
+                        rtol=1e-2,  # hidden states are less accurate
+                    ),
+                    f"Hidden states mismatch at layer {i}",
+                )
+        del transformers_model
+        del ov_model
+        gc.collect()
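Two details worth noting from these tests: the looser tolerances for hidden states (atol=1e-3, rtol=1e-2) versus attentions (atol=1e-4, rtol=1e-4) match the inline comments that attentions stay accurate while hidden states drift more between the PyTorch and OpenVINO graphs; and, as the `["pt", "np"]` loop checks, `OVModelForCustomTasks` mirrors the input framework in its outputs. A minimal sketch of the latter, assuming the same test checkpoint:

```python
import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoFeatureExtractor

from optimum.intel import OVModelForCustomTasks

# Checkpoint from the tests above; its graph exposes numbered
# "hidden_states.*" outputs grouped into a single `hidden_states` field.
model_id = "IlyasMoutawwakil/vit-with-hidden_states"
preprocessor = AutoFeatureExtractor.from_pretrained(model_id)
model = OVModelForCustomTasks.from_pretrained(model_id)

image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)

# NumPy inputs yield NumPy outputs; PyTorch inputs yield torch.Tensor outputs.
np_outputs = model(**preprocessor(images=image, return_tensors="np"))
pt_outputs = model(**preprocessor(images=image, return_tensors="pt"))

assert isinstance(np_outputs.logits, np.ndarray)
assert isinstance(pt_outputs.logits, torch.Tensor)
assert len(np_outputs.hidden_states) == len(pt_outputs.hidden_states)
```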

tests/openvino/utils_tests.py (+2)

@@ -100,6 +100,8 @@
     "unispeech": "hf-internal-testing/tiny-random-unispeech",
     "unispeech_sat": "hf-internal-testing/tiny-random-UnispeechSatModel",
     "vit": "hf-internal-testing/tiny-random-vit",
+    "vit-with-attentions": "IlyasMoutawwakil/vit-with-attentions",
+    "vit-with-hidden-states": "IlyasMoutawwakil/vit-with-hidden_states",
     "vision-encoder-decoder": "hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2",
     "wavlm": "hf-internal-testing/tiny-random-WavlmModel",
     "wav2vec2": "anton-l/wav2vec2-random-tiny-classifier",
