Custom tasks modeling #669

Merged 12 commits on Apr 19, 2024
4 changes: 3 additions & 1 deletion optimum/intel/__init__.py
@@ -112,8 +112,9 @@
"OVModelForAudioClassification",
"OVModelForAudioFrameClassification",
"OVModelForAudioXVector",
"OVModelForCTC",
"OVModelForCausalLM",
"OVModelForCTC",
"OVModelForCustomTasks",
"OVModelForFeatureExtraction",
"OVModelForImageClassification",
"OVModelForMaskedLM",
@@ -235,6 +236,7 @@
OVModelForAudioXVector,
OVModelForCausalLM,
OVModelForCTC,
OVModelForCustomTasks,
OVModelForFeatureExtraction,
OVModelForImageClassification,
OVModelForMaskedLM,
1 change: 1 addition & 0 deletions optimum/intel/openvino/__init__.py
@@ -49,6 +49,7 @@
OVModelForAudioFrameClassification,
OVModelForAudioXVector,
OVModelForCTC,
OVModelForCustomTasks,
OVModelForFeatureExtraction,
OVModelForImageClassification,
OVModelForMaskedLM,
61 changes: 61 additions & 0 deletions optimum/intel/openvino/modeling.py
@@ -43,6 +43,7 @@
CausalLMOutput,
ImageClassifierOutput,
MaskedLMOutput,
ModelOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
@@ -953,3 +954,63 @@ def forward(
        logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"]

        return TokenClassifierOutput(logits=logits)


CUSTOM_TASKS_EXAMPLE = """
    Example of a custom task (e.g. a sentence-transformers model that returns `pooler_output`):

    ```python
    >>> from transformers import {processor_class}
    >>> from optimum.intel import {model_class}

    >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> inputs = tokenizer("I love burritos!", return_tensors="np")

    >>> outputs = model(**inputs)
    >>> last_hidden_state = outputs.last_hidden_state
    >>> pooler_output = outputs.pooler_output
    ```
"""


@add_start_docstrings(
    """
    OpenVINO Model for custom tasks.
    """,
    MODEL_START_DOCSTRING,
)
class OVModelForCustomTasks(OVModel):
    """
    OpenVINO Model for any custom task. It can be used to leverage inference acceleration for any single-file OpenVINO model with custom inputs and outputs.
    """

    @add_start_docstrings_to_model_forward(
        CUSTOM_TASKS_EXAMPLE.format(
            processor_class=_TOKENIZER_FOR_DOC,
            model_class="OVModelForCustomTasks",
            checkpoint="sentence-transformers/all-MiniLM-L6-v2",
        )
    )
    def forward(self, **kwargs):
        np_inputs = isinstance(next(iter(kwargs.values())), np.ndarray)

        inputs = {}
        for key, value in kwargs.items():
            inputs[key] = np.array(value) if not np_inputs else value

        outputs = self.request(inputs)

        model_outputs = {}
        for key, value in outputs.items():
            # An OpenVINO output can carry several tensor names; take the first.
            key_name = next(iter(key.names))
            if "." in key_name:
                # Outputs with a dotted suffix (e.g. `hidden_states.3`) are grouped
                # into a list under their shared prefix.
                key_name = key_name.split(".")[0]
                if key_name not in model_outputs:
                    model_outputs[key_name] = []
                model_outputs[key_name].append(torch.from_numpy(value).to(self.device) if not np_inputs else value)
            else:
                model_outputs[key_name] = torch.from_numpy(value).to(self.device) if not np_inputs else value

        return ModelOutput(**model_outputs)
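
The core of the new class is the output-grouping rule in `forward`: OpenVINO outputs whose names carry a dotted index (e.g. `hidden_states.0`, `hidden_states.1`) are collected into a single list-valued field under their shared prefix, while plain names pass through as single tensors. A minimal standalone sketch of that rule, independent of OpenVINO (the output names and shapes below are illustrative assumptions, and index order is taken to follow insertion order, as the `forward` above also assumes):

```python
import numpy as np

# Illustrative raw outputs, named the way an exported model might name them.
raw_outputs = {
    "logits": np.zeros((1, 2)),
    "attentions.0": np.zeros((1, 4, 8, 8)),
    "attentions.1": np.zeros((1, 4, 8, 8)),
}

model_outputs = {}
for name, value in raw_outputs.items():
    if "." in name:
        # Dotted names collect into a list under the shared prefix.
        model_outputs.setdefault(name.split(".")[0], []).append(value)
    else:
        # Plain names pass through as single tensors.
        model_outputs[name] = value

assert isinstance(model_outputs["attentions"], list)
assert len(model_outputs["attentions"]) == 2
assert model_outputs["logits"].shape == (1, 2)
```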
85 changes: 85 additions & 0 deletions tests/openvino/test_modeling.py
@@ -63,6 +63,7 @@
OVModelForAudioXVector,
OVModelForCausalLM,
OVModelForCTC,
OVModelForCustomTasks,
OVModelForFeatureExtraction,
OVModelForImageClassification,
OVModelForMaskedLM,
@@ -1525,3 +1526,87 @@ def test_pipeline_image_to_text(self, model_arch: str):
self.assertIsInstance(outputs[0]["generated_text"], str)

gc.collect()


class OVModelForCustomTasksIntegrationTest(unittest.TestCase):
    SUPPORTED_ARCHITECTURES_WITH_ATTENTION = ["vit-with-attentions"]
    SUPPORTED_ARCHITECTURES_WITH_HIDDEN_STATES = ["vit-with-hidden-states"]

    def _get_sample_image(self):
        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        image = Image.open(requests.get(url, stream=True).raw)
        return image

    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_ATTENTION)
    def test_compare_output_attentions(self, model_arch):
        model_id = MODEL_NAMES[model_arch]

        image = self._get_sample_image()
        preprocessor = AutoFeatureExtractor.from_pretrained(model_id)
        inputs = preprocessor(images=image, return_tensors="pt")

        transformers_model = AutoModelForImageClassification.from_pretrained(model_id)
        transformers_model.eval()
        with torch.no_grad():
            transformers_outputs = transformers_model(**inputs, output_attentions=True)

        ov_model = OVModelForCustomTasks.from_pretrained(model_id, ov_config=F32_CONFIG)
        self.assertIsInstance(ov_model.config, PretrainedConfig)

        for input_type in ["pt", "np"]:
            inputs = preprocessor(images=image, return_tensors=input_type)
            ov_outputs = ov_model(**inputs)
            self.assertIn("logits", ov_outputs)
            self.assertIsInstance(ov_outputs.logits, TENSOR_ALIAS_TO_TYPE[input_type])
            self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-4))
            self.assertTrue(len(ov_outputs.attentions) == len(transformers_outputs.attentions))
            for i in range(len(ov_outputs.attentions)):
                self.assertTrue(
                    torch.allclose(
                        torch.Tensor(ov_outputs.attentions[i]),
                        transformers_outputs.attentions[i],
                        atol=1e-4,  # attentions are accurate
                        rtol=1e-4,  # attentions are accurate
                    ),
                    f"Attention mismatch at layer {i}",
                )

        del transformers_model
        del ov_model
        gc.collect()

    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_HIDDEN_STATES)
    def test_compare_output_hidden_states(self, model_arch):
        model_id = MODEL_NAMES[model_arch]

        image = self._get_sample_image()
        preprocessor = AutoFeatureExtractor.from_pretrained(model_id)
        inputs = preprocessor(images=image, return_tensors="pt")

        transformers_model = AutoModelForImageClassification.from_pretrained(model_id)
        transformers_model.eval()
        with torch.no_grad():
            transformers_outputs = transformers_model(**inputs, output_hidden_states=True)

        ov_model = OVModelForCustomTasks.from_pretrained(model_id, ov_config=F32_CONFIG)
        self.assertIsInstance(ov_model.config, PretrainedConfig)
        for input_type in ["pt", "np"]:
            inputs = preprocessor(images=image, return_tensors=input_type)
            ov_outputs = ov_model(**inputs)
            self.assertIn("logits", ov_outputs)
            self.assertIsInstance(ov_outputs.logits, TENSOR_ALIAS_TO_TYPE[input_type])
            self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-4))
            self.assertTrue(len(ov_outputs.hidden_states) == len(transformers_outputs.hidden_states))
            for i in range(len(ov_outputs.hidden_states)):
                self.assertTrue(
                    torch.allclose(
                        torch.Tensor(ov_outputs.hidden_states[i]),
                        transformers_outputs.hidden_states[i],
                        atol=1e-3,  # hidden states are less accurate
                        rtol=1e-2,  # hidden states are less accurate
                    ),
                    f"Hidden states mismatch at layer {i}",
                )
        del transformers_model
        del ov_model
        gc.collect()
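
Condensed, the flow these tests exercise looks like the following standalone sketch; it assumes the `IlyasMoutawwakil/vit-with-attentions` test checkpoint stays available on the Hub and exposes `logits` plus dotted `attentions.*` outputs:

```python
import requests
from PIL import Image
from transformers import AutoFeatureExtractor

from optimum.intel import OVModelForCustomTasks

# Same sample image the tests download.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

model_id = "IlyasMoutawwakil/vit-with-attentions"  # test checkpoint referenced by this PR
preprocessor = AutoFeatureExtractor.from_pretrained(model_id)
model = OVModelForCustomTasks.from_pretrained(model_id)

inputs = preprocessor(images=image, return_tensors="np")
outputs = model(**inputs)

print(outputs.logits.shape)     # classification logits
print(len(outputs.attentions))  # per-layer attention maps, grouped from attentions.*
```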
2 changes: 2 additions & 0 deletions tests/openvino/utils_tests.py
@@ -100,6 +100,8 @@
"unispeech": "hf-internal-testing/tiny-random-unispeech",
"unispeech_sat": "hf-internal-testing/tiny-random-UnispeechSatModel",
"vit": "hf-internal-testing/tiny-random-vit",
"vit-with-attentions": "IlyasMoutawwakil/vit-with-attentions",
"vit-with-hidden-states": "IlyasMoutawwakil/vit-with-hidden_states",
"vision-encoder-decoder": "hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2",
"wavlm": "hf-internal-testing/tiny-random-WavlmModel",
"wav2vec2": "anton-l/wav2vec2-random-tiny-classifier",