Custom tasks modeling #669

Merged · 12 commits · Apr 19, 2024
4 changes: 3 additions & 1 deletion optimum/intel/__init__.py
@@ -112,8 +112,9 @@
"OVModelForAudioClassification",
"OVModelForAudioFrameClassification",
"OVModelForAudioXVector",
"OVModelForCTC",
"OVModelForCausalLM",
"OVModelForCTC",
"OVModelForCustomTasks",
"OVModelForFeatureExtraction",
"OVModelForImageClassification",
"OVModelForMaskedLM",
@@ -235,6 +236,7 @@
OVModelForAudioXVector,
OVModelForCausalLM,
OVModelForCTC,
OVModelForCustomTasks,
OVModelForFeatureExtraction,
OVModelForImageClassification,
OVModelForMaskedLM,
1 change: 1 addition & 0 deletions optimum/intel/openvino/__init__.py
@@ -49,6 +49,7 @@
OVModelForAudioFrameClassification,
OVModelForAudioXVector,
OVModelForCTC,
OVModelForCustomTasks,
OVModelForFeatureExtraction,
OVModelForImageClassification,
OVModelForMaskedLM,
61 changes: 61 additions & 0 deletions optimum/intel/openvino/modeling.py
@@ -43,6 +43,7 @@
CausalLMOutput,
ImageClassifierOutput,
MaskedLMOutput,
ModelOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
@@ -953,3 +954,63 @@ def forward(
logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"]

return TokenClassifierOutput(logits=logits)


CUSTOM_TASKS_EXAMPLE = """
Example of a custom task (e.g. a sentence-transformers model returning `pooler_output`):

```python
>>> from transformers import {processor_class}
>>> from optimum.intel import {model_class}

>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")

>>> inputs = tokenizer("I love burritos!", return_tensors="np")

>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooler_output = outputs.pooler_output
```
"""


@add_start_docstrings(
"""
OpenVINO Model for custom tasks.
""",
MODEL_START_DOCSTRING,
)
class OVModelForCustomTasks(OVModel):
"""
OpenVINO Model for any custom tasks. It can be used to leverage the inference acceleration for any single-file ONNX model, that may use custom inputs and outputs.
"""

@add_start_docstrings_to_model_forward(
CUSTOM_TASKS_EXAMPLE.format(
processor_class=_TOKENIZER_FOR_DOC,
model_class="OVModelForCustomTasks",
checkpoint="sentence-transformers/all-MiniLM-L6-v2",
)
)
def forward(self, **kwargs):
np_inputs = isinstance(next(iter(kwargs.values())), np.ndarray)

inputs = {}

for key, value in kwargs.items():
inputs[key] = np.array(value) if not np_inputs else value

# Run inference
outputs = self.request(inputs)

model_outputs = {}
        for key, value in outputs.items():
            # Outputs exported without a tensor name fall back to the same
            # placeholder assigned in OVBaseModel.__init__ below.
            if len(key.names) == 0:
                key_names = {"no_name_output_O_o"}
            else:
                key_names = key.names

model_outputs[next(iter(key_names))] = torch.from_numpy(value).to(self.device) if not np_inputs else value

return ModelOutput(**model_outputs)
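
For reference, a minimal usage sketch of the new class, using the `IlyasMoutawwakil/vit-with-attentions` checkpoint exercised in the tests below; output names beyond `logits` depend on how the model was exported:

```python
import requests
from PIL import Image
from transformers import AutoFeatureExtractor

from optimum.intel import OVModelForCustomTasks

checkpoint = "IlyasMoutawwakil/vit-with-attentions"  # checkpoint from the tests below
model = OVModelForCustomTasks.from_pretrained(checkpoint)
preprocessor = AutoFeatureExtractor.from_pretrained(checkpoint)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# forward() passes every keyword argument straight to the OpenVINO inference request
outputs = model(**preprocessor(images=image, return_tensors="pt"))
print(outputs.logits.shape)
```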
4 changes: 4 additions & 0 deletions optimum/intel/openvino/modeling_base.py
@@ -85,6 +85,10 @@ def __init__(
output_names = {}
for idx, key in enumerate(model.outputs):
names = tuple(key.get_names())

            # Use the same placeholder as OVModelForCustomTasks.forward for
            # outputs that carry no tensor names.
            if len(names) == 0:
                names = ("no_name_output_O_o",)

output_names[next((name for name in names if "/" not in name), names[0])] = idx
self.output_names = output_names
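
For context, the placeholder only kicks in for IR outputs that carry no tensor names; a quick sketch of how those names can be inspected with the OpenVINO runtime (`model.xml` is a placeholder path):

```python
import openvino as ov

core = ov.Core()
ov_model = core.read_model("model.xml")  # placeholder path to an OpenVINO IR

for idx, output in enumerate(ov_model.outputs):
    names = tuple(output.get_names())
    # A nameless output would be exposed as "no_name_output_O_o" by the code above.
    print(idx, names or ("no_name_output_O_o",))
```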

34 changes: 34 additions & 0 deletions tests/openvino/test_modeling.py
@@ -63,6 +63,7 @@
OVModelForAudioXVector,
OVModelForCausalLM,
OVModelForCTC,
OVModelForCustomTasks,
OVModelForFeatureExtraction,
OVModelForImageClassification,
OVModelForMaskedLM,
@@ -1525,3 +1526,36 @@ def test_pipeline_image_to_text(self, model_arch: str):
self.assertIsInstance(outputs[0]["generated_text"], str)

gc.collect()


class OVModelForCustomTasksIntegrationTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = ["vit-with-attentions"]

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_compare_to_transformers(self, model_arch):
model_id = MODEL_NAMES[model_arch]
set_seed(SEED)
ov_model = OVModelForCustomTasks.from_pretrained(model_id, ov_config=F32_CONFIG)
self.assertIsInstance(ov_model.config, PretrainedConfig)
transformers_model = AutoModelForImageClassification.from_pretrained(model_id)
preprocessor = AutoFeatureExtractor.from_pretrained(model_id)
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = preprocessor(images=image, return_tensors="pt")

        with torch.no_grad():
            transformers_outputs = transformers_model(**inputs, output_attentions=True)

for input_type in ["pt", "np"]:
inputs = preprocessor(images=image, return_tensors=input_type)
ov_outputs = ov_model(**inputs)
self.assertIn("logits", ov_outputs)
self.assertIsInstance(ov_outputs.logits, TENSOR_ALIAS_TO_TYPE[input_type])
            # Compare tensor outputs with the transformers reference
            self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-4))
            self.assertEqual(len(ov_outputs.attentions), len(transformers_outputs.attentions))
            for i in range(len(ov_outputs.attentions)):
                self.assertTrue(
                    torch.allclose(torch.Tensor(ov_outputs.attentions[i]), transformers_outputs.attentions[i], atol=1e-4)
                )
del transformers_model
del ov_model
gc.collect()
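
As the assertions above suggest, `forward` mirrors the input framework: numpy inputs give numpy outputs, anything else is converted to numpy for inference and returned as torch tensors. A small sketch of that contract, assuming the `ov_model`, `preprocessor`, and `image` objects from the test above:

```python
import numpy as np
import torch

# Hypothetical continuation of the test above: `ov_model` and `preprocessor` already loaded.
np_outputs = ov_model(**preprocessor(images=image, return_tensors="np"))
assert isinstance(np_outputs.logits, np.ndarray)  # numpy in -> numpy out

pt_outputs = ov_model(**preprocessor(images=image, return_tensors="pt"))
assert isinstance(pt_outputs.logits, torch.Tensor)  # torch (or any non-numpy) in -> torch out
```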
1 change: 1 addition & 0 deletions tests/openvino/utils_tests.py
@@ -100,6 +100,7 @@
"unispeech": "hf-internal-testing/tiny-random-unispeech",
"unispeech_sat": "hf-internal-testing/tiny-random-UnispeechSatModel",
"vit": "hf-internal-testing/tiny-random-vit",
"vit-with-attentions": "IlyasMoutawwakil/vit-with-attentions",
"vision-encoder-decoder": "hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2",
"wavlm": "hf-internal-testing/tiny-random-WavlmModel",
"wav2vec2": "anton-l/wav2vec2-random-tiny-classifier",