diff --git a/optimum/intel/__init__.py b/optimum/intel/__init__.py
index c097562651..dec39c75db 100644
--- a/optimum/intel/__init__.py
+++ b/optimum/intel/__init__.py
@@ -112,8 +112,9 @@
         "OVModelForAudioClassification",
         "OVModelForAudioFrameClassification",
         "OVModelForAudioXVector",
-        "OVModelForCTC",
         "OVModelForCausalLM",
+        "OVModelForCTC",
+        "OVModelForCustomTasks",
         "OVModelForFeatureExtraction",
         "OVModelForImageClassification",
         "OVModelForMaskedLM",
@@ -235,6 +236,7 @@
         OVModelForAudioXVector,
         OVModelForCausalLM,
         OVModelForCTC,
+        OVModelForCustomTasks,
         OVModelForFeatureExtraction,
         OVModelForImageClassification,
         OVModelForMaskedLM,
diff --git a/optimum/intel/openvino/__init__.py b/optimum/intel/openvino/__init__.py
index 0cd7d8a029..b871668588 100644
--- a/optimum/intel/openvino/__init__.py
+++ b/optimum/intel/openvino/__init__.py
@@ -49,6 +49,7 @@
     OVModelForAudioFrameClassification,
     OVModelForAudioXVector,
     OVModelForCTC,
+    OVModelForCustomTasks,
     OVModelForFeatureExtraction,
     OVModelForImageClassification,
     OVModelForMaskedLM,
diff --git a/optimum/intel/openvino/modeling.py b/optimum/intel/openvino/modeling.py
index 8a816609fa..9c7c2b5258 100644
--- a/optimum/intel/openvino/modeling.py
+++ b/optimum/intel/openvino/modeling.py
@@ -43,6 +43,7 @@
     CausalLMOutput,
     ImageClassifierOutput,
     MaskedLMOutput,
+    ModelOutput,
     QuestionAnsweringModelOutput,
     SequenceClassifierOutput,
     TokenClassifierOutput,
@@ -953,3 +954,66 @@ def forward(
         logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"]
 
         return TokenClassifierOutput(logits=logits)
+
+
+CUSTOM_TASKS_EXAMPLE = """
+    Example of a custom task (e.g. a sentence-transformers model with a pooler head):
+
+    ```python
+    >>> from transformers import {processor_class}
+    >>> from optimum.intel import {model_class}
+
+    >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
+    >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+    >>> inputs = tokenizer("I love burritos!", return_tensors="np")
+
+    >>> outputs = model(**inputs)
+    >>> last_hidden_state = outputs.last_hidden_state
+    >>> pooler_output = outputs.pooler_output
+    ```
+"""
+
+
+@add_start_docstrings(
+    """
+    OpenVINO Model for custom tasks. It can be used to leverage the inference acceleration for any single-file OpenVINO model that uses custom inputs or outputs.
+    """,
+    MODEL_START_DOCSTRING,
+)
+class OVModelForCustomTasks(OVModel):
+    @add_start_docstrings_to_model_forward(
+        CUSTOM_TASKS_EXAMPLE.format(
+            processor_class=_TOKENIZER_FOR_DOC,
+            model_class="OVModelForCustomTasks",
+            checkpoint="IlyasMoutawwakil/sbert-all-MiniLM-L6-v2-with-pooler",
+        )
+    )
+    def forward(self, **kwargs):
+        expected_inputs_names = set(self.input_names)
+        inputs_names = set(kwargs)
+
+        if not expected_inputs_names.issubset(inputs_names):
+            raise ValueError(
+                f"Missing required inputs: expected the following inputs: {', '.join(expected_inputs_names)} but only got: {', '.join(inputs_names)}."
+            )
+
+        np_inputs = isinstance(next(iter(kwargs.values())), np.ndarray)
+        inputs = {}
+        for input_name in self.input_names:
+            inputs[input_name] = np.array(kwargs.pop(input_name)) if not np_inputs else kwargs.pop(input_name)
+
+        outputs = self.request(inputs)
+
+        model_outputs = {}
+        for key, value in outputs.items():
+            key_name = next(iter(key.names))
+            if "." in key_name:
+                key_name = key_name.split(".")[0]
+                if key_name not in model_outputs:
+                    model_outputs[key_name] = []
+                model_outputs[key_name].append(torch.from_numpy(value).to(self.device) if not np_inputs else value)
+            else:
+                model_outputs[key_name] = torch.from_numpy(value).to(self.device) if not np_inputs else value
+
+        return ModelOutput(**model_outputs)
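Note on the dotted-name grouping in `OVModelForCustomTasks.forward`: any OpenVINO output whose tensor name contains a `.` is collected into a list under the name's prefix, which is how per-layer outputs exported as `attentions.0`, `attentions.1`, ... come back as a single `attentions` field on the returned `ModelOutput`. A minimal standalone sketch of that logic, with a plain dict and illustrative names standing in for the OpenVINO request results:

```python
import numpy as np

# Stand-in for the name/value pairs an OpenVINO inference request returns;
# the names and shapes here are illustrative only.
raw_outputs = {
    "logits": np.zeros((1, 2)),
    "attentions.0": np.zeros((1, 4, 8, 8)),
    "attentions.1": np.zeros((1, 4, 8, 8)),
}

model_outputs = {}
for key_name, value in raw_outputs.items():
    if "." in key_name:
        # Dotted names are grouped: "attentions.0", "attentions.1" -> "attentions".
        prefix = key_name.split(".")[0]
        model_outputs.setdefault(prefix, []).append(value)
    else:
        model_outputs[key_name] = value

assert len(model_outputs["attentions"]) == 2  # one entry per layer
```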
diff --git a/tests/openvino/test_export.py b/tests/openvino/test_export.py
index 21bec021f8..9d1daaab63 100644
--- a/tests/openvino/test_export.py
+++ b/tests/openvino/test_export.py
@@ -19,15 +19,18 @@
 from typing import Optional
 
 from parameterized import parameterized
+from transformers import AutoConfig
 from utils_tests import MODEL_NAMES
 
 from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
-from optimum.exporters.openvino import export_from_model
+from optimum.exporters.onnx.model_configs import BertOnnxConfig
+from optimum.exporters.openvino import export_from_model, main_export
 from optimum.exporters.tasks import TasksManager
 from optimum.intel import (
     OVLatentConsistencyModelPipeline,
     OVModelForAudioClassification,
     OVModelForCausalLM,
+    OVModelForCustomTasks,
     OVModelForFeatureExtraction,
     OVModelForImageClassification,
     OVModelForMaskedLM,
@@ -114,3 +117,39 @@ def _openvino_export(
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_export(self, model_type: str):
         self._openvino_export(model_type)
+
+
+class CustomExportModelTest(unittest.TestCase):
+    def test_export_custom_model(self):
+        class BertOnnxConfigWithPooler(BertOnnxConfig):
+            @property
+            def outputs(self):
+                if self.task == "feature-extraction-with-pooler":
+                    common_outputs = {}
+                    common_outputs["last_hidden_state"] = {0: "batch_size", 1: "sequence_length"}
+                    common_outputs["pooler_output"] = {0: "batch_size"}
+                else:
+                    common_outputs = super().outputs
+
+                return common_outputs
+
+        base_task = "feature-extraction"
+        custom_task = f"{base_task}-with-pooler"
+        model_id = "sentence-transformers/all-MiniLM-L6-v2"
+
+        config = AutoConfig.from_pretrained(model_id)
+        custom_export_configs = {"model": BertOnnxConfigWithPooler(config, task=custom_task)}
+
+        with TemporaryDirectory() as tmpdirname:
+            main_export(
+                model_name_or_path=model_id,
+                custom_export_configs=custom_export_configs,
+                library_name="transformers",
+                output=Path(tmpdirname),
+                task=base_task,
+            )
+
+            ov_model = OVModelForCustomTasks.from_pretrained(tmpdirname)
+
+            self.assertIsInstance(ov_model, OVBaseModel)
+            self.assertTrue(ov_model.output_names == {"last_hidden_state": 0, "pooler_output": 1})
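For reference, a model exported this way is consumed exactly as the new docstring example shows: the extra `pooler_output` declared by the custom config surfaces as a field on the returned `ModelOutput`. A minimal usage sketch, where `model_dir` is a placeholder for wherever the `main_export` output landed:

```python
from transformers import AutoTokenizer

from optimum.intel import OVModelForCustomTasks

# Placeholder path: the directory produced by main_export in the test above.
model_dir = "path/to/exported-model"

model = OVModelForCustomTasks.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

inputs = tokenizer("I love burritos!", return_tensors="np")
outputs = model(**inputs)

print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)
print(outputs.pooler_output.shape)  # (batch_size, hidden_size)
```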
diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index 907c767310..f84cac8161 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -63,6 +63,7 @@
     OVModelForAudioXVector,
     OVModelForCausalLM,
     OVModelForCTC,
+    OVModelForCustomTasks,
     OVModelForFeatureExtraction,
     OVModelForImageClassification,
     OVModelForMaskedLM,
@@ -1525,3 +1526,87 @@ def test_pipeline_image_to_text(self, model_arch: str):
         self.assertIsInstance(outputs[0]["generated_text"], str)
 
         gc.collect()
+
+
+class OVModelForCustomTasksIntegrationTest(unittest.TestCase):
+    SUPPORTED_ARCHITECTURES_WITH_ATTENTION = ["vit-with-attentions"]
+    SUPPORTED_ARCHITECTURES_WITH_HIDDEN_STATES = ["vit-with-hidden-states"]
+
+    def _get_sample_image(self):
+        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        image = Image.open(requests.get(url, stream=True).raw)
+        return image
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_ATTENTION)
+    def test_compare_output_attentions(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+
+        image = self._get_sample_image()
+        preprocessor = AutoFeatureExtractor.from_pretrained(model_id)
+        inputs = preprocessor(images=image, return_tensors="pt")
+
+        transformers_model = AutoModelForImageClassification.from_pretrained(model_id)
+        transformers_model.eval()
+        with torch.no_grad():
+            transformers_outputs = transformers_model(**inputs, output_attentions=True)
+
+        ov_model = OVModelForCustomTasks.from_pretrained(model_id, ov_config=F32_CONFIG)
+        self.assertIsInstance(ov_model.config, PretrainedConfig)
+
+        for input_type in ["pt", "np"]:
+            inputs = preprocessor(images=image, return_tensors=input_type)
+            ov_outputs = ov_model(**inputs)
+            self.assertIn("logits", ov_outputs)
+            self.assertIsInstance(ov_outputs.logits, TENSOR_ALIAS_TO_TYPE[input_type])
+            self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-4))
+            self.assertTrue(len(ov_outputs.attentions) == len(transformers_outputs.attentions))
+            for i in range(len(ov_outputs.attentions)):
+                self.assertTrue(
+                    torch.allclose(
+                        torch.Tensor(ov_outputs.attentions[i]),
+                        transformers_outputs.attentions[i],
+                        atol=1e-4,  # attentions are accurate
+                        rtol=1e-4,  # attentions are accurate
+                    ),
+                    f"Attention mismatch at layer {i}",
+                )
+
+        del transformers_model
+        del ov_model
+        gc.collect()
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_HIDDEN_STATES)
+    def test_compare_output_hidden_states(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+
+        image = self._get_sample_image()
+        preprocessor = AutoFeatureExtractor.from_pretrained(model_id)
+        inputs = preprocessor(images=image, return_tensors="pt")
+
+        transformers_model = AutoModelForImageClassification.from_pretrained(model_id)
+        transformers_model.eval()
+        with torch.no_grad():
+            transformers_outputs = transformers_model(**inputs, output_hidden_states=True)
+
+        ov_model = OVModelForCustomTasks.from_pretrained(model_id, ov_config=F32_CONFIG)
+        self.assertIsInstance(ov_model.config, PretrainedConfig)
+        for input_type in ["pt", "np"]:
+            inputs = preprocessor(images=image, return_tensors=input_type)
+            ov_outputs = ov_model(**inputs)
+            self.assertIn("logits", ov_outputs)
+            self.assertIsInstance(ov_outputs.logits, TENSOR_ALIAS_TO_TYPE[input_type])
+            self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-4))
+            self.assertTrue(len(ov_outputs.hidden_states) == len(transformers_outputs.hidden_states))
+            for i in range(len(ov_outputs.hidden_states)):
+                self.assertTrue(
+                    torch.allclose(
+                        torch.Tensor(ov_outputs.hidden_states[i]),
+                        transformers_outputs.hidden_states[i],
+                        atol=1e-3,  # hidden states are less accurate
+                        rtol=1e-2,  # hidden states are less accurate
+                    ),
+                    f"Hidden states mismatch at layer {i}",
+                )
+        del transformers_model
+        del ov_model
+        gc.collect()
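The diff does not show how the `vit-with-attentions` and `vit-with-hidden-states` test checkpoints were produced. One plausible recipe, following the same custom-config pattern as the export test above, is sketched below; `ViTOnnxConfigWithAttentions` is a hypothetical name, and it assumes the checkpoint's config enables `output_attentions=True` so the traced forward actually returns attention maps. The dotted `attentions.{i}` output names are exactly what `OVModelForCustomTasks` groups back into a single `attentions` field.

```python
# Hypothetical sketch, not the actual recipe used for the test checkpoints.
from optimum.exporters.onnx.model_configs import ViTOnnxConfig


class ViTOnnxConfigWithAttentions(ViTOnnxConfig):
    @property
    def outputs(self):
        common_outputs = super().outputs
        # Expose one attention map per layer under a dotted name so that
        # OVModelForCustomTasks regroups them into `outputs.attentions`.
        for i in range(self._config.num_hidden_layers):
            common_outputs[f"attentions.{i}"] = {0: "batch_size"}
        return common_outputs
```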
"hf-internal-testing/tiny-random-WavlmModel", "wav2vec2": "anton-l/wav2vec2-random-tiny-classifier",