diff --git a/optimum/intel/ipex/inference.py b/optimum/intel/ipex/inference.py
index ccf2da9d80..a628ebe12e 100644
--- a/optimum/intel/ipex/inference.py
+++ b/optimum/intel/ipex/inference.py
@@ -97,6 +97,10 @@ def __init__(
             jit (`boolean = False`, *optional*):
                 Enable jit to accelerate inference speed
         """
+        logger.warning(
+            "`inference_mode` is deprecated and will be removed in v1.18.0. Use `pipeline` to load and export your model to TorchScript instead."
+        )
+
         if not is_ipex_available():
             raise ImportError(IPEX_NOT_AVAILABLE_ERROR_MSG)

diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index 2b739ea502..d2963d55a1 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -161,6 +161,7 @@ def _from_transformers(
         local_files_only: bool = False,
         torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
         trust_remote_code: bool = False,
+        _commit_hash: str = None,
     ):
         if use_auth_token is not None:
             warnings.warn(
@@ -186,6 +187,7 @@ def _from_transformers(
             "force_download": force_download,
             "torch_dtype": torch_dtype,
             "trust_remote_code": trust_remote_code,
+            "_commit_hash": _commit_hash,
         }

         model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)
diff --git a/optimum/intel/pipelines/__init__.py b/optimum/intel/pipelines/__init__.py
new file mode 100644
index 0000000000..40a1e3ca56
--- /dev/null
+++ b/optimum/intel/pipelines/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .pipeline_base import pipeline
diff --git a/optimum/intel/pipelines/pipeline_base.py b/optimum/intel/pipelines/pipeline_base.py
new file mode 100644
index 0000000000..65e6cfb782
--- /dev/null
+++ b/optimum/intel/pipelines/pipeline_base.py
@@ -0,0 +1,290 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+
+import torch
+from transformers import AutoConfig, AutoFeatureExtractor, AutoTokenizer
+from transformers import pipeline as transformers_pipeline
+from transformers.feature_extraction_utils import PreTrainedFeatureExtractor
+from transformers.pipelines import (
+    AudioClassificationPipeline,
+    FillMaskPipeline,
+    ImageClassificationPipeline,
+    QuestionAnsweringPipeline,
+    TextClassificationPipeline,
+    TextGenerationPipeline,
+    TokenClassificationPipeline,
+)
+from transformers.pipelines.base import Pipeline
+from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers.utils import logging
+
+from optimum.intel.utils import is_ipex_available
+
+
+if is_ipex_available():
+    from ..ipex.modeling_base import (
+        IPEXModel,
+        IPEXModelForAudioClassification,
+        IPEXModelForCausalLM,
+        IPEXModelForImageClassification,
+        IPEXModelForMaskedLM,
+        IPEXModelForQuestionAnswering,
+        IPEXModelForSequenceClassification,
+        IPEXModelForTokenClassification,
+    )
+
+    IPEX_SUPPORTED_TASKS = {
+        "text-generation": {
+            "impl": TextGenerationPipeline,
+            "class": (IPEXModelForCausalLM,),
+            "default": "gpt2",
+            "type": "text",
+        },
+        "fill-mask": {
+            "impl": FillMaskPipeline,
+            "class": (IPEXModelForMaskedLM,),
+            "default": "bert-base-cased",
+            "type": "text",
+        },
+        "question-answering": {
+            "impl": QuestionAnsweringPipeline,
+            "class": (IPEXModelForQuestionAnswering,),
+            "default": "distilbert-base-cased-distilled-squad",
+            "type": "text",
+        },
+        "image-classification": {
+            "impl": ImageClassificationPipeline,
+            "class": (IPEXModelForImageClassification,),
+            "default": "google/vit-base-patch16-224",
+            "type": "image",
+        },
+        "text-classification": {
+            "impl": TextClassificationPipeline,
+            "class": (IPEXModelForSequenceClassification,),
+            "default": "distilbert-base-uncased-finetuned-sst-2-english",
+            "type": "text",
+        },
+        "token-classification": {
+            "impl": TokenClassificationPipeline,
+            "class": (IPEXModelForTokenClassification,),
+            "default": "dbmdz/bert-large-cased-finetuned-conll03-english",
+            "type": "text",
+        },
+        "audio-classification": {
+            "impl": AudioClassificationPipeline,
+            "class": (IPEXModelForAudioClassification,),
+            "default": "superb/hubert-base-superb-ks",
+            "type": "audio",
+        },
+    }
+else:
+    IPEX_SUPPORTED_TASKS = {}
+
+
+def load_ipex_model(
+    model,
+    targeted_task,
+    SUPPORTED_TASKS,
+    model_kwargs: Optional[Dict[str, Any]] = None,
+    hub_kwargs: Optional[Dict[str, Any]] = None,
+):
+    if model_kwargs is None:
+        model_kwargs = {}
+
+    ipex_model_class = SUPPORTED_TASKS[targeted_task]["class"][0]
+
+    if model is None:
+        model_id = SUPPORTED_TASKS[targeted_task]["default"]
+        model = ipex_model_class.from_pretrained(model_id, export=True, **model_kwargs, **hub_kwargs)
+    elif isinstance(model, str):
+        model_id = model
+        try:
+            config = AutoConfig.from_pretrained(model)
+            export = not getattr(config, "torchscript", False)
+        except RuntimeError:
+            logger.warning("We will use IPEXModel with export=True to export the model")
+            export = True
+        model = ipex_model_class.from_pretrained(model, export=export, **model_kwargs, **hub_kwargs)
+    elif isinstance(model, IPEXModel):
+        model_id = getattr(model.config, "name_or_path", None)
+    else:
+        raise ValueError(
+            f"""Model {model} is not supported. Please provide a valid model name or path or an IPEXModel.
+            You can also provide no model, in which case the default model for the task will be used."""
+        )
+
+    return model, model_id
+
+
+MAPPING_LOADING_FUNC = {
+    "ipex": load_ipex_model,
+}
+
+
+if TYPE_CHECKING:
+    from transformers.modeling_utils import PreTrainedModel
+    from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
+
+
+logger = logging.get_logger(__name__)
+
+
+def pipeline(
+    task: str = None,
+    model: Optional[Union[str, "PreTrainedModel"]] = None,
+    tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None,
+    feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None,
+    use_fast: bool = True,
+    token: Optional[Union[str, bool]] = None,
+    accelerator: Optional[str] = "ipex",
+    revision: Optional[str] = None,
+    trust_remote_code: Optional[bool] = None,
+    torch_dtype: Optional[Union[str, torch.dtype]] = None,
+    commit_hash: Optional[str] = None,
+    **model_kwargs,
+) -> Pipeline:
+    """
+    Utility factory method to build a [`Pipeline`].
+
+    Pipelines are made of:
+
+        - A [tokenizer](tokenizer) in charge of mapping raw textual input to tokens.
+        - A [model](model) to make predictions from the inputs.
+        - Some (optional) post processing for enhancing model's output.
+
+    Args:
+        task (`str`):
+            The task defining which pipeline will be returned. Currently accepted tasks are:
+
+            - `"text-generation"`: will return a [`TextGenerationPipeline`].
+
+        model (`str` or [`PreTrainedModel`], *optional*):
+            The model that will be used by the pipeline to make predictions. This can be a model identifier or an
+            actual instance of a pretrained model inheriting from [`PreTrainedModel`] (for PyTorch).
+
+            If not provided, the default for the `task` will be loaded.
+        tokenizer (`str` or [`PreTrainedTokenizer`], *optional*):
+            The tokenizer that will be used by the pipeline to encode data for the model. This can be a model
+            identifier or an actual pretrained tokenizer inheriting from [`PreTrainedTokenizer`].
+
+            If not provided, the default tokenizer for the given `model` will be loaded (if it is a string). If `model`
+            is not specified or not a string, then the default tokenizer for `config` is loaded (if it is a string).
+            However, if `config` is also not given or not a string, then the default tokenizer for the given `task`
+            will be loaded.
+        accelerator (`str`, *optional*, defaults to `"ipex"`):
+            The optimization backend to use. Currently only `"ipex"` is supported.
+        use_fast (`bool`, *optional*, defaults to `True`):
+            Whether or not to use a Fast tokenizer if possible (a [`PreTrainedTokenizerFast`]).
+        torch_dtype (`str` or `torch.dtype`, *optional*):
+            Sent directly as `model_kwargs` (just a simpler shortcut) to use the available precision for this model
+            (`torch.float16`, `torch.bfloat16`, ... or `"auto"`).
+        model_kwargs (`Dict[str, Any]`, *optional*):
+            Additional dictionary of keyword arguments passed along to the model's `from_pretrained(...,
+            **model_kwargs)` function.
+
+    Returns:
+        [`Pipeline`]: A suitable pipeline for the task.
+
+    Examples:
+
+    ```python
+    >>> import torch
+    >>> from optimum.intel.pipelines import pipeline
+
+    >>> pipe = pipeline('text-generation', 'gpt2', torch_dtype=torch.bfloat16)
+    >>> pipe("Describe a real-world application of AI in sustainable energy.")
+    ```"""
+    if model_kwargs is None:
+        model_kwargs = {}
+
+    if task is None and model is None:
+        raise RuntimeError(
+            "Impossible to instantiate a pipeline without either a task or a model "
+            "being specified. "
" + "Please provide a task class or a model" + ) + + if model is None and tokenizer is not None: + raise RuntimeError( + "Impossible to instantiate a pipeline with tokenizer specified but not the model as the provided tokenizer" + " may not be compatible with the default model. Please provide a PreTrainedModel class or a" + " path/identifier to a pretrained model when providing tokenizer." + ) + + if accelerator not in MAPPING_LOADING_FUNC: + raise ValueError( + f'Accelerator {accelerator} is not supported. Supported accelerator is {", ".join(MAPPING_LOADING_FUNC)}.' + ) + + if accelerator == "ipex": + if task not in list(IPEX_SUPPORTED_TASKS.keys()): + raise ValueError( + f"Task {task} is not supported for the IPEX pipeline. Supported tasks are { list(IPEX_SUPPORTED_TASKS.keys())}" + ) + + supported_tasks = IPEX_SUPPORTED_TASKS if accelerator == "ipex" else None + + no_feature_extractor_tasks = set() + no_tokenizer_tasks = set() + for _task, values in supported_tasks.items(): + if values["type"] == "text": + no_feature_extractor_tasks.add(_task) + elif values["type"] in {"image", "video"}: + no_tokenizer_tasks.add(_task) + elif values["type"] in {"audio"}: + no_tokenizer_tasks.add(_task) + elif values["type"] not in ["multimodal", "audio", "video"]: + raise ValueError(f"SUPPORTED_TASK {_task} contains invalid type {values['type']}") + + load_tokenizer = task not in no_tokenizer_tasks + load_feature_extractor = task not in no_feature_extractor_tasks + + hub_kwargs = { + "revision": revision, + "token": token, + "trust_remote_code": trust_remote_code, + "_commit_hash": commit_hash, + } + + if isinstance(model, Path): + model = str(model) + + if torch_dtype is not None: + if "torch_dtype" in model_kwargs: + raise ValueError( + 'You cannot use both `pipeline(... torch_dtype=..., model_kwargs={"torch_dtype":...})` as those' + " arguments might conflict, use only one.)" + ) + model_kwargs["torch_dtype"] = torch_dtype + + # Load the correct model if possible + # Infer the framework from the model if not already defined + model, model_id = MAPPING_LOADING_FUNC[accelerator](model, task, supported_tasks, model_kwargs, hub_kwargs) + + if load_tokenizer and tokenizer is None: + tokenizer = AutoTokenizer.from_pretrained(model_id, **hub_kwargs, **model_kwargs) + if load_feature_extractor and feature_extractor is None: + feature_extractor = AutoFeatureExtractor.from_pretrained(model_id, **hub_kwargs, **model_kwargs) + + return transformers_pipeline( + task, + model=model, + tokenizer=tokenizer, + feature_extractor=feature_extractor, + use_fast=use_fast, + torch_dtype=torch_dtype, + ) diff --git a/tests/ipex/test_pipelines.py b/tests/ipex/test_pipelines.py new file mode 100644 index 0000000000..89a27ab2c8 --- /dev/null +++ b/tests/ipex/test_pipelines.py @@ -0,0 +1,265 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import unittest
+from tempfile import TemporaryDirectory
+
+import numpy as np
+import torch
+from parameterized import parameterized
+from transformers import AutoTokenizer
+from transformers.pipelines import pipeline as transformers_pipeline
+
+from optimum.intel.ipex.modeling_base import (
+    IPEXModelForAudioClassification,
+    IPEXModelForCausalLM,
+    IPEXModelForImageClassification,
+    IPEXModelForMaskedLM,
+    IPEXModelForQuestionAnswering,
+    IPEXModelForSequenceClassification,
+    IPEXModelForTokenClassification,
+)
+from optimum.intel.pipelines import pipeline as ipex_pipeline
+
+
+MODEL_NAMES = {
+    "albert": "hf-internal-testing/tiny-random-albert",
+    "beit": "hf-internal-testing/tiny-random-BeitForImageClassification",
+    "bert": "hf-internal-testing/tiny-random-bert",
+    "bart": "hf-internal-testing/tiny-random-bart",
+    "blenderbot-small": "hf-internal-testing/tiny-random-BlenderbotModel",
+    "blenderbot": "hf-internal-testing/tiny-random-BlenderbotModel",
+    "bloom": "hf-internal-testing/tiny-random-BloomModel",
+    "convbert": "hf-internal-testing/tiny-random-ConvBertForSequenceClassification",
+    "codegen": "hf-internal-testing/tiny-random-CodeGenForCausalLM",
+    "convnext": "hf-internal-testing/tiny-random-convnext",
+    "distilbert": "hf-internal-testing/tiny-random-distilbert",
+    "electra": "hf-internal-testing/tiny-random-electra",
+    "flaubert": "hf-internal-testing/tiny-random-flaubert",
+    "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
+    "gpt2": "hf-internal-testing/tiny-random-gpt2",
+    "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
+    "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
+    "gptj": "hf-internal-testing/tiny-random-GPTJModel",
+    "levit": "hf-internal-testing/tiny-random-LevitModel",
+    "llama": "fxmarty/tiny-llama-fast-tokenizer",
+    "llama2": "Jiqing/tiny_random_llama2",
+    "marian": "sshleifer/tiny-marian-en-de",
+    "mbart": "hf-internal-testing/tiny-random-mbart",
+    "mistral": "echarlaix/tiny-random-mistral",
+    "mobilenet_v1": "google/mobilenet_v1_0.75_192",
+    "mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model",
+    "mobilevit": "hf-internal-testing/tiny-random-mobilevit",
+    "mpt": "hf-internal-testing/tiny-random-MptForCausalLM",
+    "mt5": "stas/mt5-tiny-random",
+    "opt": "hf-internal-testing/tiny-random-OPTModel",
+    "phi": "echarlaix/tiny-random-PhiForCausalLM",
+    "resnet": "hf-internal-testing/tiny-random-resnet",
+    "roberta": "hf-internal-testing/tiny-random-roberta",
+    "roformer": "hf-internal-testing/tiny-random-roformer",
+    "squeezebert": "hf-internal-testing/tiny-random-squeezebert",
+    "t5": "hf-internal-testing/tiny-random-t5",
+    "unispeech": "hf-internal-testing/tiny-random-unispeech",
+    "vit": "hf-internal-testing/tiny-random-vit",
+    "wav2vec2": "anton-l/wav2vec2-random-tiny-classifier",
+    "xlm": "hf-internal-testing/tiny-random-xlm",
+}
+
+
+class PipelinesIntegrationTest(unittest.TestCase):
+    COMMON_SUPPORTED_ARCHITECTURES = (
+        "albert",
+        "bert",
+        "distilbert",
+        "electra",
+        "flaubert",
+        "roberta",
+        "roformer",
+        "squeezebert",
+        "xlm",
+    )
+    TEXT_GENERATION_SUPPORTED_ARCHITECTURES = (
+        "bart",
+        "gpt_bigcode",
+        "blenderbot",
+        "blenderbot-small",
+        "bloom",
+        "codegen",
+        "gpt2",
+        "gpt_neo",
+        "gpt_neox",
+        "llama",
+        "llama2",
+        "mistral",
+        "mpt",
+        "opt",
+    )
+    QUESTION_ANSWERING_SUPPORTED_ARCHITECTURES = (
+        "bert",
+        "distilbert",
+        "roberta",
+    )
+    AUDIO_CLASSIFICATION_SUPPORTED_ARCHITECTURES = (
+        "unispeech",
+        "wav2vec2",
+    )
+    IMAGE_CLASSIFICATION_SUPPORTED_ARCHITECTURES = (
+        "beit",
"mobilenet_v1", + "mobilenet_v2", + "mobilevit", + "resnet", + "vit", + ) + + @parameterized.expand(COMMON_SUPPORTED_ARCHITECTURES) + def test_token_classification_pipeline_inference(self, model_arch): + model_id = MODEL_NAMES[model_arch] + transformers_generator = transformers_pipeline("token-classification", model_id) + ipex_generator = ipex_pipeline("token-classification", model_id, accelerator="ipex") + inputs = "Hello I'm Omar and I live in Zürich." + with torch.inference_mode(): + transformers_output = transformers_generator(inputs) + with torch.inference_mode(): + ipex_output = ipex_generator(inputs) + self.assertEqual(len(transformers_output), len(ipex_output)) + self.assertTrue(isinstance(ipex_generator.model, IPEXModelForTokenClassification)) + self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule)) + for i in range(len(transformers_output)): + self.assertAlmostEqual(transformers_output[i]["score"], ipex_output[i]["score"], delta=1e-4) + + @parameterized.expand(COMMON_SUPPORTED_ARCHITECTURES) + def test_sequence_classification_pipeline_inference(self, model_arch): + model_id = MODEL_NAMES[model_arch] + transformers_generator = transformers_pipeline("text-classification", model_id) + ipex_generator = ipex_pipeline("text-classification", model_id, accelerator="ipex") + inputs = "This restaurant is awesome" + with torch.inference_mode(): + transformers_output = transformers_generator(inputs) + with torch.inference_mode(): + ipex_output = ipex_generator(inputs) + self.assertTrue(isinstance(ipex_generator.model, IPEXModelForSequenceClassification)) + self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule)) + self.assertEqual(transformers_output[0]["label"], ipex_output[0]["label"]) + self.assertAlmostEqual(transformers_output[0]["score"], ipex_output[0]["score"], delta=1e-4) + + @parameterized.expand(COMMON_SUPPORTED_ARCHITECTURES) + def test_fill_mask_pipeline_inference(self, model_arch): + model_id = MODEL_NAMES[model_arch] + inputs = "The Milky Way is a galaxy." + transformers_generator = transformers_pipeline("fill-mask", model_id) + ipex_generator = ipex_pipeline("fill-mask", model_id, accelerator="ipex") + mask_token = transformers_generator.tokenizer.mask_token + inputs = inputs.replace("", mask_token) + with torch.inference_mode(): + transformers_output = transformers_generator(inputs) + with torch.inference_mode(): + ipex_output = ipex_generator(inputs) + self.assertEqual(len(transformers_output), len(ipex_output)) + self.assertTrue(isinstance(ipex_generator.model, IPEXModelForMaskedLM)) + self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule)) + for i in range(len(transformers_output)): + self.assertEqual(transformers_output[i]["token"], ipex_output[i]["token"]) + self.assertAlmostEqual(transformers_output[i]["score"], ipex_output[i]["score"], delta=1e-4) + + @parameterized.expand(TEXT_GENERATION_SUPPORTED_ARCHITECTURES) + def test_text_generation_pipeline_inference(self, model_arch): + model_id = MODEL_NAMES[model_arch] + transformers_generator = transformers_pipeline("text-generation", model_id) + ipex_generator = ipex_pipeline("text-generation", model_id, accelerator="ipex") + inputs = "Describe a real-world application of AI." 
+        with torch.inference_mode():
+            transformers_output = transformers_generator(inputs)
+        with torch.inference_mode():
+            ipex_output = ipex_generator(inputs)
+        self.assertTrue(isinstance(ipex_generator.model, IPEXModelForCausalLM))
+        self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule))
+        self.assertEqual(transformers_output[0]["generated_text"], ipex_output[0]["generated_text"])
+
+    @parameterized.expand(QUESTION_ANSWERING_SUPPORTED_ARCHITECTURES)
+    def test_question_answering_pipeline_inference(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        transformers_generator = transformers_pipeline("question-answering", model_id)
+        ipex_generator = ipex_pipeline("question-answering", model_id, accelerator="ipex")
+        question = "How many programming languages does BLOOM support?"
+        context = "BLOOM has 176 billion parameters and can generate text in 46 natural languages and 13 programming languages."
+        with torch.inference_mode():
+            transformers_output = transformers_generator(question=question, context=context)
+        with torch.inference_mode():
+            ipex_output = ipex_generator(question=question, context=context)
+        self.assertTrue(isinstance(ipex_generator.model, IPEXModelForQuestionAnswering))
+        self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule))
+        self.assertAlmostEqual(transformers_output["score"], ipex_output["score"], delta=1e-4)
+        self.assertEqual(transformers_output["start"], ipex_output["start"])
+        self.assertEqual(transformers_output["end"], ipex_output["end"])
+
+    @parameterized.expand(AUDIO_CLASSIFICATION_SUPPORTED_ARCHITECTURES)
+    def test_audio_classification_pipeline_inference(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        transformers_generator = transformers_pipeline("audio-classification", model_id)
+        ipex_generator = ipex_pipeline("audio-classification", model_id, accelerator="ipex")
+        inputs = [np.random.random(16000)]
+        with torch.inference_mode():
+            transformers_output = transformers_generator(inputs)
+        with torch.inference_mode():
+            ipex_output = ipex_generator(inputs)
+        self.assertTrue(isinstance(ipex_generator.model, IPEXModelForAudioClassification))
+        self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule))
+        self.assertAlmostEqual(transformers_output[0][0]["score"], ipex_output[0][0]["score"], delta=1e-2)
+        self.assertAlmostEqual(transformers_output[0][1]["score"], ipex_output[0][1]["score"], delta=1e-2)
+
+    @parameterized.expand(IMAGE_CLASSIFICATION_SUPPORTED_ARCHITECTURES)
+    def test_image_classification_pipeline_inference(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        transformers_generator = transformers_pipeline("image-classification", model_id)
+        ipex_generator = ipex_pipeline("image-classification", model_id, accelerator="ipex")
+        inputs = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        with torch.inference_mode():
+            transformers_output = transformers_generator(inputs)
+        with torch.inference_mode():
+            ipex_output = ipex_generator(inputs)
+        self.assertEqual(len(transformers_output), len(ipex_output))
+        self.assertTrue(isinstance(ipex_generator.model, IPEXModelForImageClassification))
+        self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule))
+        for i in range(len(transformers_output)):
+            self.assertEqual(transformers_output[i]["label"], ipex_output[i]["label"])
+            self.assertAlmostEqual(transformers_output[i]["score"], ipex_output[i]["score"], delta=1e-4)
+
+    @parameterized.expand(COMMON_SUPPORTED_ARCHITECTURES)
+    def test_pipeline_load_from_ipex_model(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        model = IPEXModelForSequenceClassification.from_pretrained(model_id, export=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        ipex_generator = ipex_pipeline("text-classification", model, tokenizer=tokenizer, accelerator="ipex")
+        inputs = "This restaurant is awesome"
+        with torch.inference_mode():
+            ipex_output = ipex_generator(inputs)
+        self.assertTrue(isinstance(ipex_generator.model, IPEXModelForSequenceClassification))
+        self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule))
+        self.assertGreaterEqual(ipex_output[0]["score"], 0.0)
+
+    @parameterized.expand(COMMON_SUPPORTED_ARCHITECTURES)
+    def test_pipeline_load_from_jit_model(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        model = IPEXModelForSequenceClassification.from_pretrained(model_id, export=True)
+        save_dir = TemporaryDirectory().name
+        model.save_pretrained(save_dir)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        ipex_generator = ipex_pipeline("text-classification", save_dir, tokenizer=tokenizer, accelerator="ipex")
+        inputs = "This restaurant is awesome"
+        with torch.inference_mode():
+            ipex_output = ipex_generator(inputs)
+        self.assertTrue(isinstance(ipex_generator.model, IPEXModelForSequenceClassification))
+        self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule))
+        self.assertGreaterEqual(ipex_output[0]["score"], 0.0)