From c71b411be48e721accb7c8764b206c9e659417d1 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 17 Apr 2024 08:51:20 +0200 Subject: [PATCH 01/12] added custom tasks modeling --- optimum/intel/__init__.py | 4 ++- optimum/intel/openvino/__init__.py | 1 + optimum/intel/openvino/modeling.py | 55 ++++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 1 deletion(-) diff --git a/optimum/intel/__init__.py b/optimum/intel/__init__.py index c097562651..dec39c75db 100644 --- a/optimum/intel/__init__.py +++ b/optimum/intel/__init__.py @@ -112,8 +112,9 @@ "OVModelForAudioClassification", "OVModelForAudioFrameClassification", "OVModelForAudioXVector", - "OVModelForCTC", "OVModelForCausalLM", + "OVModelForCTC", + "OVModelForCustomTasks", "OVModelForFeatureExtraction", "OVModelForImageClassification", "OVModelForMaskedLM", @@ -235,6 +236,7 @@ OVModelForAudioXVector, OVModelForCausalLM, OVModelForCTC, + OVModelForCustomTasks, OVModelForFeatureExtraction, OVModelForImageClassification, OVModelForMaskedLM, diff --git a/optimum/intel/openvino/__init__.py b/optimum/intel/openvino/__init__.py index 0cd7d8a029..b871668588 100644 --- a/optimum/intel/openvino/__init__.py +++ b/optimum/intel/openvino/__init__.py @@ -49,6 +49,7 @@ OVModelForAudioFrameClassification, OVModelForAudioXVector, OVModelForCTC, + OVModelForCustomTasks, OVModelForFeatureExtraction, OVModelForImageClassification, OVModelForMaskedLM, diff --git a/optimum/intel/openvino/modeling.py b/optimum/intel/openvino/modeling.py index 8a816609fa..fb68f71cfb 100644 --- a/optimum/intel/openvino/modeling.py +++ b/optimum/intel/openvino/modeling.py @@ -43,6 +43,7 @@ CausalLMOutput, ImageClassifierOutput, MaskedLMOutput, + ModelOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput, @@ -953,3 +954,57 @@ def forward( logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"] return TokenClassifierOutput(logits=logits) + + +CUSTOM_TASKS_EXAMPLE = """ + Example of custom tasks (e.g. a sentence transformers taking `pooler_output` as output): + + ```python + >>> from transformers import {processor_class} + >>> from optimum.intel import {model_class} + + >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> inputs = tokenizer("I love burritos!", return_tensors="np") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooler_output = outputs.pooler_output + ``` +""" + + +@add_start_docstrings( + """ + OpenVINO Model for custom tasks. + """, + MODEL_START_DOCSTRING, +) +class OVModelForCustomTasks(OVModel): + """ + OpenVINO Model for any custom tasks. It can be used to leverage the inference acceleration for any single-file ONNX model, that may use custom inputs and outputs. + """ + + @add_start_docstrings_to_model_forward( + CUSTOM_TASKS_EXAMPLE.format( + processor_class=_TOKENIZER_FOR_DOC, + model_class="OVModelForCustomTasks", + checkpoint="sentence-transformers/all-MiniLM-L6-v2", + ) + ) + def forward(self, **kwargs): + np_inputs = isinstance(next(iter(kwargs.values())), np.ndarray) + + inputs = {} + + for key, value in kwargs.items(): + inputs[key] = np.array(value) if not np_inputs else value + + # Run inference + outputs = self.request(inputs) + + for key, value in outputs.items(): + outputs[key] = torch.from_numpy(value).to(self.device) if not np_inputs else value + + return ModelOutput(**outputs) From d282ae25bcbd2fee80a9fd19bd01b9756c3332ba Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 17 Apr 2024 14:51:34 +0200 Subject: [PATCH 02/12] patched output names for now and added vit with a attentions test --- optimum/intel/openvino/modeling.py | 10 ++++++-- optimum/intel/openvino/modeling_base.py | 4 +++ tests/openvino/test_modeling.py | 34 +++++++++++++++++++++++++ tests/openvino/utils_tests.py | 1 + 4 files changed, 47 insertions(+), 2 deletions(-) diff --git a/optimum/intel/openvino/modeling.py b/optimum/intel/openvino/modeling.py index fb68f71cfb..dad351aba8 100644 --- a/optimum/intel/openvino/modeling.py +++ b/optimum/intel/openvino/modeling.py @@ -1004,7 +1004,13 @@ def forward(self, **kwargs): # Run inference outputs = self.request(inputs) + model_outputs = {} for key, value in outputs.items(): - outputs[key] = torch.from_numpy(value).to(self.device) if not np_inputs else value + if len(key.names) == 0: + key_names = {"no_name_output_O_o"} + else: + key_names = key.names - return ModelOutput(**outputs) + model_outputs[next(iter(key_names))] = torch.from_numpy(value).to(self.device) if not np_inputs else value + + return ModelOutput(**model_outputs) diff --git a/optimum/intel/openvino/modeling_base.py b/optimum/intel/openvino/modeling_base.py index d5b19bb28c..420810c1ae 100644 --- a/optimum/intel/openvino/modeling_base.py +++ b/optimum/intel/openvino/modeling_base.py @@ -85,6 +85,10 @@ def __init__( output_names = {} for idx, key in enumerate(model.outputs): names = tuple(key.get_names()) + + if len(names) == 0: + names = ("no_name_output_O_o",) + output_names[next((name for name in names if "/" not in name), names[0])] = idx self.output_names = output_names diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index 907c767310..1ca0521187 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -63,6 +63,7 @@ OVModelForAudioXVector, OVModelForCausalLM, OVModelForCTC, + OVModelForCustomTasks, OVModelForFeatureExtraction, OVModelForImageClassification, OVModelForMaskedLM, @@ -1525,3 +1526,36 @@ def test_pipeline_image_to_text(self, model_arch: str): self.assertIsInstance(outputs[0]["generated_text"], str) gc.collect() + + +class OVModelForCustomTasksIntegrationTest(unittest.TestCase): + SUPPORTED_ARCHITECTURES = ["vit-with-attentions"] + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_compare_to_transformers(self, model_arch): + model_id = MODEL_NAMES[model_arch] + set_seed(SEED) + ov_model = OVModelForCustomTasks.from_pretrained(model_id, ov_config=F32_CONFIG) + self.assertIsInstance(ov_model.config, PretrainedConfig) + transformers_model = AutoModelForImageClassification.from_pretrained(model_id) + preprocessor = AutoFeatureExtractor.from_pretrained(model_id) + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw) + inputs = preprocessor(images=image, return_tensors="pt") + + # with torch.no_grad(): + # transformers_outputs = transformers_model(**inputs, output_attentions=True) + + for input_type in ["pt", "np"]: + inputs = preprocessor(images=image, return_tensors=input_type) + ov_outputs = ov_model(**inputs) + self.assertIn("logits", ov_outputs) + self.assertIsInstance(ov_outputs.logits, TENSOR_ALIAS_TO_TYPE[input_type]) + # Compare tensor outputs + # self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-4)) + # self.assertTrue( + # torch.allclose(torch.Tensor(ov_outputs.attentions), transformers_outputs.attentions, atol=1e-4) + # ) + del transformers_model + del ov_model + gc.collect() diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index 73224c81b2..b2dada248c 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -100,6 +100,7 @@ "unispeech": "hf-internal-testing/tiny-random-unispeech", "unispeech_sat": "hf-internal-testing/tiny-random-UnispeechSatModel", "vit": "hf-internal-testing/tiny-random-vit", + "vit-with-attentions": "IlyasMoutawwakil/vit-with-attentions", "vision-encoder-decoder": "hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2", "wavlm": "hf-internal-testing/tiny-random-WavlmModel", "wav2vec2": "anton-l/wav2vec2-random-tiny-classifier", From d14ac621ff0f644e2d79cc3ff2da61f2f86b9a2a Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 17 Apr 2024 15:42:35 +0200 Subject: [PATCH 03/12] test passing --- tests/openvino/test_modeling.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index 1ca0521187..7697ca6444 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -1542,20 +1542,17 @@ def test_compare_to_transformers(self, model_arch): url = "http://images.cocodataset.org/val2017/000000039769.jpg" image = Image.open(requests.get(url, stream=True).raw) inputs = preprocessor(images=image, return_tensors="pt") - - # with torch.no_grad(): - # transformers_outputs = transformers_model(**inputs, output_attentions=True) - + with torch.no_grad(): + transformers_outputs = transformers_model(**inputs, output_attentions=True) for input_type in ["pt", "np"]: inputs = preprocessor(images=image, return_tensors=input_type) ov_outputs = ov_model(**inputs) self.assertIn("logits", ov_outputs) self.assertIsInstance(ov_outputs.logits, TENSOR_ALIAS_TO_TYPE[input_type]) - # Compare tensor outputs - # self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-4)) - # self.assertTrue( - # torch.allclose(torch.Tensor(ov_outputs.attentions), transformers_outputs.attentions, atol=1e-4) - # ) + self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-4)) + self.assertTrue( + torch.allclose(torch.Tensor(ov_outputs.attentions), transformers_outputs.attentions, atol=1e-4) + ) del transformers_model del ov_model gc.collect() From 8bf40fada9b3c7fd092ed49c22ccfa7031e1d898 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 17 Apr 2024 16:37:08 +0200 Subject: [PATCH 04/12] fix attentions --- tests/openvino/test_modeling.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index 7697ca6444..5b770fbb37 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -1551,7 +1551,12 @@ def test_compare_to_transformers(self, model_arch): self.assertIsInstance(ov_outputs.logits, TENSOR_ALIAS_TO_TYPE[input_type]) self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-4)) self.assertTrue( - torch.allclose(torch.Tensor(ov_outputs.attentions), transformers_outputs.attentions, atol=1e-4) + all( + torch.allclose( + torch.Tensor(ov_outputs.attentions[i]), transformers_outputs.attentions[i], atol=1e-4 + ) + for i in range(len(ov_outputs.attentions)) + ) ) del transformers_model del ov_model From eddccab3067cc3b24b34c84c23f0e52ccb23e7dc Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 17 Apr 2024 17:36:05 +0200 Subject: [PATCH 05/12] added hidden states test --- tests/openvino/test_modeling.py | 60 ++++++++++++++++++++++++++++----- tests/openvino/utils_tests.py | 1 + 2 files changed, 52 insertions(+), 9 deletions(-) diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index 5b770fbb37..ceccd31ae1 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -1529,21 +1529,30 @@ def test_pipeline_image_to_text(self, model_arch: str): class OVModelForCustomTasksIntegrationTest(unittest.TestCase): - SUPPORTED_ARCHITECTURES = ["vit-with-attentions"] + SUPPORTED_ARCHITECTURES_WITH_ATTENTION = ["vit-with-attentions"] + SUPPORTED_ARCHITECTURES_WITH_HIDDEN_STATES = ["vit-with-hidden-states"] - @parameterized.expand(SUPPORTED_ARCHITECTURES) - def test_compare_to_transformers(self, model_arch): - model_id = MODEL_NAMES[model_arch] - set_seed(SEED) - ov_model = OVModelForCustomTasks.from_pretrained(model_id, ov_config=F32_CONFIG) - self.assertIsInstance(ov_model.config, PretrainedConfig) - transformers_model = AutoModelForImageClassification.from_pretrained(model_id) - preprocessor = AutoFeatureExtractor.from_pretrained(model_id) + def _get_sample_image(self): url = "http://images.cocodataset.org/val2017/000000039769.jpg" image = Image.open(requests.get(url, stream=True).raw) + return image + + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_ATTENTION) + def test_compare_output_attentions(self, model_arch): + model_id = MODEL_NAMES[model_arch] + + image = self._get_sample_image() + preprocessor = AutoFeatureExtractor.from_pretrained(model_id) inputs = preprocessor(images=image, return_tensors="pt") + + transformers_model = AutoModelForImageClassification.from_pretrained(model_id) + transformers_model.eval() with torch.no_grad(): transformers_outputs = transformers_model(**inputs, output_attentions=True) + + ov_model = OVModelForCustomTasks.from_pretrained(model_id, ov_config=F32_CONFIG) + self.assertIsInstance(ov_model.config, PretrainedConfig) + for input_type in ["pt", "np"]: inputs = preprocessor(images=image, return_tensors=input_type) ov_outputs = ov_model(**inputs) @@ -1561,3 +1570,36 @@ def test_compare_to_transformers(self, model_arch): del transformers_model del ov_model gc.collect() + + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_HIDDEN_STATES) + def test_compare_output_hidden_states(self, model_arch): + model_id = MODEL_NAMES[model_arch] + + image = self._get_sample_image() + preprocessor = AutoFeatureExtractor.from_pretrained(model_id) + inputs = preprocessor(images=image, return_tensors="pt") + + transformers_model = AutoModelForImageClassification.from_pretrained(model_id) + transformers_model.eval() + with torch.no_grad(): + transformers_outputs = transformers_model(**inputs, output_hidden_states=True) + + ov_model = OVModelForCustomTasks.from_pretrained(model_id, ov_config=F32_CONFIG) + self.assertIsInstance(ov_model.config, PretrainedConfig) + for input_type in ["pt", "np"]: + inputs = preprocessor(images=image, return_tensors=input_type) + ov_outputs = ov_model(**inputs) + self.assertIn("logits", ov_outputs) + self.assertIsInstance(ov_outputs.logits, TENSOR_ALIAS_TO_TYPE[input_type]) + self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-4)) + self.assertTrue( + all( + torch.allclose( + torch.Tensor(ov_outputs.hidden_states[i]), transformers_outputs.hidden_states[i], atol=1e-4 + ) + for i in range(len(ov_outputs.hidden_states)) + ) + ) + del transformers_model + del ov_model + gc.collect() diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index b2dada248c..c610479dd7 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -101,6 +101,7 @@ "unispeech_sat": "hf-internal-testing/tiny-random-UnispeechSatModel", "vit": "hf-internal-testing/tiny-random-vit", "vit-with-attentions": "IlyasMoutawwakil/vit-with-attentions", + "vit-with-hidden-states": "IlyasMoutawwakil/vit-with-hidden_states", "vision-encoder-decoder": "hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2", "wavlm": "hf-internal-testing/tiny-random-WavlmModel", "wav2vec2": "anton-l/wav2vec2-random-tiny-classifier", From 1ab12a4563a7d277bef54d544dcc2387b2d3b179 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 18 Apr 2024 11:42:50 +0200 Subject: [PATCH 06/12] remove unnecessary names processing --- optimum/intel/openvino/modeling.py | 14 +++++++------- optimum/intel/openvino/modeling_base.py | 4 ---- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/optimum/intel/openvino/modeling.py b/optimum/intel/openvino/modeling.py index dad351aba8..69af16563e 100644 --- a/optimum/intel/openvino/modeling.py +++ b/optimum/intel/openvino/modeling.py @@ -997,20 +997,20 @@ def forward(self, **kwargs): np_inputs = isinstance(next(iter(kwargs.values())), np.ndarray) inputs = {} - for key, value in kwargs.items(): inputs[key] = np.array(value) if not np_inputs else value - # Run inference outputs = self.request(inputs) model_outputs = {} for key, value in outputs.items(): - if len(key.names) == 0: - key_names = {"no_name_output_O_o"} + key_name = next(iter(key.names)) + if "." in key_name: + key_name = key_name.split(".")[0] + if key_name not in model_outputs: + model_outputs[key_name] = [] + model_outputs[key_name].append(torch.from_numpy(value).to(self.device) if not np_inputs else value) else: - key_names = key.names - - model_outputs[next(iter(key_names))] = torch.from_numpy(value).to(self.device) if not np_inputs else value + model_outputs[key_name] = torch.from_numpy(value).to(self.device) if not np_inputs else value return ModelOutput(**model_outputs) diff --git a/optimum/intel/openvino/modeling_base.py b/optimum/intel/openvino/modeling_base.py index 420810c1ae..d5b19bb28c 100644 --- a/optimum/intel/openvino/modeling_base.py +++ b/optimum/intel/openvino/modeling_base.py @@ -85,10 +85,6 @@ def __init__( output_names = {} for idx, key in enumerate(model.outputs): names = tuple(key.get_names()) - - if len(names) == 0: - names = ("no_name_output_O_o",) - output_names[next((name for name in names if "/" not in name), names[0])] = idx self.output_names = output_names From b284bc9b5f254ddfd515f42184b05266257bab24 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 18 Apr 2024 11:45:57 +0200 Subject: [PATCH 07/12] better testing --- tests/openvino/test_modeling.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index ceccd31ae1..f84cac8161 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -1559,14 +1559,18 @@ def test_compare_output_attentions(self, model_arch): self.assertIn("logits", ov_outputs) self.assertIsInstance(ov_outputs.logits, TENSOR_ALIAS_TO_TYPE[input_type]) self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-4)) - self.assertTrue( - all( + self.assertTrue(len(ov_outputs.attentions) == len(transformers_outputs.attentions)) + for i in range(len(ov_outputs.attentions)): + self.assertTrue( torch.allclose( - torch.Tensor(ov_outputs.attentions[i]), transformers_outputs.attentions[i], atol=1e-4 - ) - for i in range(len(ov_outputs.attentions)) + torch.Tensor(ov_outputs.attentions[i]), + transformers_outputs.attentions[i], + atol=1e-4, # attentions are accurate + rtol=1e-4, # attentions are accurate + ), + f"Attention mismatch at layer {i}", ) - ) + del transformers_model del ov_model gc.collect() @@ -1592,14 +1596,17 @@ def test_compare_output_hidden_states(self, model_arch): self.assertIn("logits", ov_outputs) self.assertIsInstance(ov_outputs.logits, TENSOR_ALIAS_TO_TYPE[input_type]) self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-4)) - self.assertTrue( - all( + self.assertTrue(len(ov_outputs.hidden_states) == len(transformers_outputs.hidden_states)) + for i in range(len(ov_outputs.hidden_states)): + self.assertTrue( torch.allclose( - torch.Tensor(ov_outputs.hidden_states[i]), transformers_outputs.hidden_states[i], atol=1e-4 - ) - for i in range(len(ov_outputs.hidden_states)) + torch.Tensor(ov_outputs.hidden_states[i]), + transformers_outputs.hidden_states[i], + atol=1e-3, # hidden states are less accurate + rtol=1e-2, # hidden states are less accurate + ), + f"Hidden states mismatch at layer {i}", ) - ) del transformers_model del ov_model gc.collect() From 17d42853bb487c6e331094f46b1300316a0821b5 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Fri, 19 Apr 2024 09:40:32 +0200 Subject: [PATCH 08/12] added inputs check Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> --- optimum/intel/openvino/modeling.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/optimum/intel/openvino/modeling.py b/optimum/intel/openvino/modeling.py index 69af16563e..29a8335ea2 100644 --- a/optimum/intel/openvino/modeling.py +++ b/optimum/intel/openvino/modeling.py @@ -994,11 +994,18 @@ class OVModelForCustomTasks(OVModel): ) ) def forward(self, **kwargs): - np_inputs = isinstance(next(iter(kwargs.values())), np.ndarray) + expected_inputs_names = set(self.input_names) + inputs_names = set(kwargs) + if not expected_inputs_names.issubset(inputs_names): + raise ValueError( + f"Got unexpected inputs: expecting the following inputs : {', '.join(expected_inputs_names)} but got : {', '.join(inputs_names)}." + ) + + np_inputs = isinstance(next(iter(kwargs.values())), np.ndarray) inputs = {} - for key, value in kwargs.items(): - inputs[key] = np.array(value) if not np_inputs else value + for input_name in self.input_names: + inputs[input_name] = np.array(kwargs.pop(input_name)) if not np_inputs else kwargs.pop(input_name) outputs = self.request(inputs) From b1d2223dd727ae2a26b3cb6f1b5e09e567893578 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 19 Apr 2024 10:15:46 +0200 Subject: [PATCH 09/12] added a bert with pooler --- optimum/intel/openvino/modeling.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/optimum/intel/openvino/modeling.py b/optimum/intel/openvino/modeling.py index 29a8335ea2..55fd9d5f37 100644 --- a/optimum/intel/openvino/modeling.py +++ b/optimum/intel/openvino/modeling.py @@ -957,7 +957,7 @@ def forward( CUSTOM_TASKS_EXAMPLE = """ - Example of custom tasks (e.g. a sentence transformers taking `pooler_output` as output): + Example of custom tasks (e.g. a sentence transformers with a pooler head): ```python >>> from transformers import {processor_class} @@ -977,20 +977,16 @@ def forward( @add_start_docstrings( """ - OpenVINO Model for custom tasks. + OpenVINO Model for custom tasks. It can be used to leverage the inference acceleration for any single-file OpenVINO model, that may use custom inputs and outputs. """, MODEL_START_DOCSTRING, ) class OVModelForCustomTasks(OVModel): - """ - OpenVINO Model for any custom tasks. It can be used to leverage the inference acceleration for any single-file ONNX model, that may use custom inputs and outputs. - """ - @add_start_docstrings_to_model_forward( CUSTOM_TASKS_EXAMPLE.format( processor_class=_TOKENIZER_FOR_DOC, model_class="OVModelForCustomTasks", - checkpoint="sentence-transformers/all-MiniLM-L6-v2", + checkpoint="IlyasMoutawwakil/bert-with-pooler", ) ) def forward(self, **kwargs): @@ -1001,7 +997,7 @@ def forward(self, **kwargs): raise ValueError( f"Got unexpected inputs: expecting the following inputs : {', '.join(expected_inputs_names)} but got : {', '.join(inputs_names)}." ) - + np_inputs = isinstance(next(iter(kwargs.values())), np.ndarray) inputs = {} for input_name in self.input_names: From 10e34cdf2ab8f7978336cca46afea1d8f619bbec Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 19 Apr 2024 10:23:49 +0200 Subject: [PATCH 10/12] fix name --- optimum/intel/openvino/modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/intel/openvino/modeling.py b/optimum/intel/openvino/modeling.py index 55fd9d5f37..9c7c2b5258 100644 --- a/optimum/intel/openvino/modeling.py +++ b/optimum/intel/openvino/modeling.py @@ -986,7 +986,7 @@ class OVModelForCustomTasks(OVModel): CUSTOM_TASKS_EXAMPLE.format( processor_class=_TOKENIZER_FOR_DOC, model_class="OVModelForCustomTasks", - checkpoint="IlyasMoutawwakil/bert-with-pooler", + checkpoint="IlyasMoutawwakil/sbert-all-MiniLM-L6-v2-with-pooler", ) ) def forward(self, **kwargs): From 3ac9f99880c0078dce63d0cdda2896dcbf66261e Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 19 Apr 2024 10:42:32 +0200 Subject: [PATCH 11/12] added a custom export test --- tests/openvino/test_export.py | 37 ++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/tests/openvino/test_export.py b/tests/openvino/test_export.py index 21bec021f8..e28fa58801 100644 --- a/tests/openvino/test_export.py +++ b/tests/openvino/test_export.py @@ -19,15 +19,18 @@ from typing import Optional from parameterized import parameterized +from transformers import AutoConfig from utils_tests import MODEL_NAMES from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED -from optimum.exporters.openvino import export_from_model +from optimum.exporters.onnx.model_configs import BertOnnxConfig +from optimum.exporters.openvino import export_from_model, main_export from optimum.exporters.tasks import TasksManager from optimum.intel import ( OVLatentConsistencyModelPipeline, OVModelForAudioClassification, OVModelForCausalLM, + OVModelForCustomTasks, OVModelForFeatureExtraction, OVModelForImageClassification, OVModelForMaskedLM, @@ -114,3 +117,35 @@ def _openvino_export( @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_export(self, model_type: str): self._openvino_export(model_type) + + +class CustomExportModelTest(unittest.TestCase): + def test_export_custom_model(self): + class BertOnnxConfigWithPooler(BertOnnxConfig): + @property + def outputs(self): + common_outputs = {} + common_outputs["last_hidden_state"] = {0: "batch_size", 1: "sequence_length"} + common_outputs["pooler_output"] = {0: "batch_size"} + return common_outputs + + base_task = "feature-extraction" + custom_task = f"{base_task}-with-pooler" + model_id = "sentence-transformers/all-MiniLM-L6-v2" + + config = AutoConfig.from_pretrained(model_id) + custom_export_configs = {"model": BertOnnxConfigWithPooler(config, task=base_task)} + + with TemporaryDirectory() as tmpdirname: + main_export( + model_name_or_path=model_id, + custom_export_configs=custom_export_configs, + library_name="transformers", + output=Path(tmpdirname), + task=base_task, + ) + + ov_model = OVModelForCustomTasks.from_pretrained(tmpdirname) + + self.assertIsInstance(ov_model, OVBaseModel) + self.assertTrue(ov_model.output_names == {"last_hidden_state": 0, "pooler_output": 1}) From 360ad674eae01d44aabce653a0ecece66f449ec2 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 19 Apr 2024 10:47:52 +0200 Subject: [PATCH 12/12] better custom config --- tests/openvino/test_export.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/openvino/test_export.py b/tests/openvino/test_export.py index e28fa58801..9d1daaab63 100644 --- a/tests/openvino/test_export.py +++ b/tests/openvino/test_export.py @@ -124,9 +124,13 @@ def test_export_custom_model(self): class BertOnnxConfigWithPooler(BertOnnxConfig): @property def outputs(self): - common_outputs = {} - common_outputs["last_hidden_state"] = {0: "batch_size", 1: "sequence_length"} - common_outputs["pooler_output"] = {0: "batch_size"} + if self.task == "feature-extraction-with-pooler": + common_outputs = {} + common_outputs["last_hidden_state"] = {0: "batch_size", 1: "sequence_length"} + common_outputs["pooler_output"] = {0: "batch_size"} + else: + common_outputs = super().outputs + return common_outputs base_task = "feature-extraction" @@ -134,7 +138,7 @@ def outputs(self): model_id = "sentence-transformers/all-MiniLM-L6-v2" config = AutoConfig.from_pretrained(model_id) - custom_export_configs = {"model": BertOnnxConfigWithPooler(config, task=base_task)} + custom_export_configs = {"model": BertOnnxConfigWithPooler(config, task=custom_task)} with TemporaryDirectory() as tmpdirname: main_export(