Skip to content

Commit d9c8f9f

Browse files
Add IPEX pipeline (#501)
* define optimum-intel pipeline * add tests and readme * fix pipelines example * fix readme codestyle * add _load_model in pipeline * update pipeline for optimum intel * update tests * remove readme * Update optimum/intel/pipelines/__init__.py Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> * fix pipelines * add all supported tasks testing * add hub_kwargs and model_kwargs on tokenizer and feature_extractor * add hub_kwargs and default pipeline tests * fix _from_transformers args * rm default pipeline test * Update optimum/intel/pipelines/pipeline_base.py Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> * Update optimum/intel/pipelines/pipeline_base.py Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> * Update optimum/intel/pipelines/pipeline_base.py Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> * Update optimum/intel/pipelines/pipeline_base.py Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> * Update optimum/intel/pipelines/pipeline_base.py Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> * Update optimum/intel/pipelines/pipeline_base.py Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> * Update optimum/intel/pipelines/pipeline_base.py Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> * Update optimum/intel/pipelines/pipeline_base.py Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> * fix comments * Update optimum/exporters/openvino/model_patcher.py * Update optimum/intel/ipex/modeling_base.py Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> * Update optimum/intel/pipelines/pipeline_base.py Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> * Update optimum/intel/pipelines/pipeline_base.py Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> * Update optimum/intel/pipelines/pipeline_base.py Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> * fix style --------- Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
1 parent d021798 commit d9c8f9f

File tree

5 files changed

+576
-0
lines changed

5 files changed

+576
-0
lines changed

optimum/intel/ipex/inference.py

+4
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ def __init__(
9797
jit (`boolean = False`, *optional*):
9898
Enable jit to accelerate inference speed
9999
"""
100+
logger.warning(
101+
"`inference_mode` is deprecated and will be removed in v1.18.0. Use `pipeline` to load and export your model to TorchScript instead."
102+
)
103+
100104
if not is_ipex_available():
101105
raise ImportError(IPEX_NOT_AVAILABLE_ERROR_MSG)
102106

optimum/intel/ipex/modeling_base.py

+2
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def _from_transformers(
161161
local_files_only: bool = False,
162162
torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
163163
trust_remote_code: bool = False,
164+
_commit_hash: str = None,
164165
):
165166
if use_auth_token is not None:
166167
warnings.warn(
@@ -186,6 +187,7 @@ def _from_transformers(
186187
"force_download": force_download,
187188
"torch_dtype": torch_dtype,
188189
"trust_remote_code": trust_remote_code,
190+
"_commit_hash": _commit_hash,
189191
}
190192

191193
model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)

optimum/intel/pipelines/__init__.py

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2024 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .pipeline_base import pipeline
+290
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
# Copyright 2024 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from pathlib import Path
16+
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
17+
18+
import torch
19+
from transformers import AutoConfig, AutoFeatureExtractor, AutoTokenizer
20+
from transformers import pipeline as transformers_pipeline
21+
from transformers.feature_extraction_utils import PreTrainedFeatureExtractor
22+
from transformers.pipelines import (
23+
AudioClassificationPipeline,
24+
FillMaskPipeline,
25+
ImageClassificationPipeline,
26+
QuestionAnsweringPipeline,
27+
TextClassificationPipeline,
28+
TextGenerationPipeline,
29+
TokenClassificationPipeline,
30+
)
31+
from transformers.pipelines.base import Pipeline
32+
from transformers.tokenization_utils import PreTrainedTokenizer
33+
from transformers.utils import logging
34+
35+
from optimum.intel.utils import is_ipex_available
36+
37+
38+
if is_ipex_available():
39+
from ..ipex.modeling_base import (
40+
IPEXModel,
41+
IPEXModelForAudioClassification,
42+
IPEXModelForCausalLM,
43+
IPEXModelForImageClassification,
44+
IPEXModelForMaskedLM,
45+
IPEXModelForQuestionAnswering,
46+
IPEXModelForSequenceClassification,
47+
IPEXModelForTokenClassification,
48+
)
49+
50+
IPEX_SUPPORTED_TASKS = {
51+
"text-generation": {
52+
"impl": TextGenerationPipeline,
53+
"class": (IPEXModelForCausalLM,),
54+
"default": "gpt2",
55+
"type": "text",
56+
},
57+
"fill-mask": {
58+
"impl": FillMaskPipeline,
59+
"class": (IPEXModelForMaskedLM,),
60+
"default": "bert-base-cased",
61+
"type": "text",
62+
},
63+
"question-answering": {
64+
"impl": QuestionAnsweringPipeline,
65+
"class": (IPEXModelForQuestionAnswering,),
66+
"default": "distilbert-base-cased-distilled-squad",
67+
"type": "text",
68+
},
69+
"image-classification": {
70+
"impl": ImageClassificationPipeline,
71+
"class": (IPEXModelForImageClassification,),
72+
"default": "google/vit-base-patch16-224",
73+
"type": "image",
74+
},
75+
"text-classification": {
76+
"impl": TextClassificationPipeline,
77+
"class": (IPEXModelForSequenceClassification,),
78+
"default": "distilbert-base-uncased-finetuned-sst-2-english",
79+
"type": "text",
80+
},
81+
"token-classification": {
82+
"impl": TokenClassificationPipeline,
83+
"class": (IPEXModelForTokenClassification,),
84+
"default": "dbmdz/bert-large-cased-finetuned-conll03-english",
85+
"type": "text",
86+
},
87+
"audio-classification": {
88+
"impl": AudioClassificationPipeline,
89+
"class": (IPEXModelForAudioClassification,),
90+
"default": "superb/hubert-base-superb-ks",
91+
"type": "audio",
92+
},
93+
}
94+
else:
95+
IPEX_SUPPORTED_TASKS = {}
96+
97+
98+
def load_ipex_model(
99+
model,
100+
targeted_task,
101+
SUPPORTED_TASKS,
102+
model_kwargs: Optional[Dict[str, Any]] = None,
103+
hub_kwargs: Optional[Dict[str, Any]] = None,
104+
):
105+
if model_kwargs is None:
106+
model_kwargs = {}
107+
108+
ipex_model_class = SUPPORTED_TASKS[targeted_task]["class"][0]
109+
110+
if model is None:
111+
model_id = SUPPORTED_TASKS[targeted_task]["default"]
112+
model = ipex_model_class.from_pretrained(model_id, export=True, **model_kwargs, **hub_kwargs)
113+
elif isinstance(model, str):
114+
model_id = model
115+
try:
116+
config = AutoConfig.from_pretrained(model)
117+
export = not getattr(config, "torchscript", False)
118+
except RuntimeError:
119+
logger.warning("We will use IPEXModel with export=True to export the model")
120+
export = True
121+
model = ipex_model_class.from_pretrained(model, export=export, **model_kwargs, **hub_kwargs)
122+
elif isinstance(model, IPEXModel):
123+
model_id = getattr(model.config, "name_or_path", None)
124+
else:
125+
raise ValueError(
126+
f"""Model {model} is not supported. Please provide a valid model name or path or a IPEXModel.
127+
You can also provide non model then a default one will be used"""
128+
)
129+
130+
return model, model_id
131+
132+
133+
MAPPING_LOADING_FUNC = {
134+
"ipex": load_ipex_model,
135+
}
136+
137+
138+
if TYPE_CHECKING:
139+
from transformers.modeling_utils import PreTrainedModel
140+
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
141+
142+
143+
logger = logging.get_logger(__name__)
144+
145+
146+
def pipeline(
147+
task: str = None,
148+
model: Optional[Union[str, "PreTrainedModel"]] = None,
149+
tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None,
150+
feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None,
151+
use_fast: bool = True,
152+
token: Optional[Union[str, bool]] = None,
153+
accelerator: Optional[str] = "ort",
154+
revision: Optional[str] = None,
155+
trust_remote_code: Optional[bool] = None,
156+
torch_dtype: Optional[Union[str, torch.dtype]] = None,
157+
commit_hash: Optional[str] = None,
158+
**model_kwargs,
159+
) -> Pipeline:
160+
"""
161+
Utility factory method to build a [`Pipeline`].
162+
163+
Pipelines are made of:
164+
165+
- A [tokenizer](tokenizer) in charge of mapping raw textual input to token.
166+
- A [model](model) to make predictions from the inputs.
167+
- Some (optional) post processing for enhancing model's output.
168+
169+
Args:
170+
task (`str`):
171+
The task defining which pipeline will be returned. Currently accepted tasks are:
172+
173+
- `"text-generation"`: will return a [`TextGenerationPipeline`]:.
174+
175+
model (`str` or [`PreTrainedModel`], *optional*):
176+
The model that will be used by the pipeline to make predictions. This can be a model identifier or an
177+
actual instance of a pretrained model inheriting from [`PreTrainedModel`] (for PyTorch).
178+
179+
If not provided, the default for the `task` will be loaded.
180+
tokenizer (`str` or [`PreTrainedTokenizer`], *optional*):
181+
The tokenizer that will be used by the pipeline to encode data for the model. This can be a model
182+
identifier or an actual pretrained tokenizer inheriting from [`PreTrainedTokenizer`].
183+
184+
If not provided, the default tokenizer for the given `model` will be loaded (if it is a string). If `model`
185+
is not specified or not a string, then the default tokenizer for `config` is loaded (if it is a string).
186+
However, if `config` is also not given or not a string, then the default tokenizer for the given `task`
187+
will be loaded.
188+
accelerator (`str`, *optional*, defaults to `"ipex"`):
189+
The optimization backends, choose from ["ipex", "inc", "openvino"].
190+
use_fast (`bool`, *optional*, defaults to `True`):
191+
Whether or not to use a Fast tokenizer if possible (a [`PreTrainedTokenizerFast`]).
192+
torch_dtype (`str` or `torch.dtype`, *optional*):
193+
Sent directly as `model_kwargs` (just a simpler shortcut) to use the available precision for this model
194+
(`torch.float16`, `torch.bfloat16`, ... or `"auto"`).
195+
model_kwargs (`Dict[str, Any]`, *optional*):
196+
Additional dictionary of keyword arguments passed along to the model's `from_pretrained(...,
197+
**model_kwargs)` function.
198+
199+
Returns:
200+
[`Pipeline`]: A suitable pipeline for the task.
201+
202+
Examples:
203+
204+
```python
205+
>>> import torch
206+
>>> from optimum.intel.pipelines import pipeline
207+
208+
>>> pipe = pipeline('text-generation', 'gpt2', torch_dtype=torch.bfloat16)
209+
>>> pipe("Describe a real-world application of AI in sustainable energy.")
210+
```"""
211+
if model_kwargs is None:
212+
model_kwargs = {}
213+
214+
if task is None and model is None:
215+
raise RuntimeError(
216+
"Impossible to instantiate a pipeline without either a task or a model "
217+
"being specified. "
218+
"Please provide a task class or a model"
219+
)
220+
221+
if model is None and tokenizer is not None:
222+
raise RuntimeError(
223+
"Impossible to instantiate a pipeline with tokenizer specified but not the model as the provided tokenizer"
224+
" may not be compatible with the default model. Please provide a PreTrainedModel class or a"
225+
" path/identifier to a pretrained model when providing tokenizer."
226+
)
227+
228+
if accelerator not in MAPPING_LOADING_FUNC:
229+
raise ValueError(
230+
f'Accelerator {accelerator} is not supported. Supported accelerator is {", ".join(MAPPING_LOADING_FUNC)}.'
231+
)
232+
233+
if accelerator == "ipex":
234+
if task not in list(IPEX_SUPPORTED_TASKS.keys()):
235+
raise ValueError(
236+
f"Task {task} is not supported for the IPEX pipeline. Supported tasks are { list(IPEX_SUPPORTED_TASKS.keys())}"
237+
)
238+
239+
supported_tasks = IPEX_SUPPORTED_TASKS if accelerator == "ipex" else None
240+
241+
no_feature_extractor_tasks = set()
242+
no_tokenizer_tasks = set()
243+
for _task, values in supported_tasks.items():
244+
if values["type"] == "text":
245+
no_feature_extractor_tasks.add(_task)
246+
elif values["type"] in {"image", "video"}:
247+
no_tokenizer_tasks.add(_task)
248+
elif values["type"] in {"audio"}:
249+
no_tokenizer_tasks.add(_task)
250+
elif values["type"] not in ["multimodal", "audio", "video"]:
251+
raise ValueError(f"SUPPORTED_TASK {_task} contains invalid type {values['type']}")
252+
253+
load_tokenizer = task not in no_tokenizer_tasks
254+
load_feature_extractor = task not in no_feature_extractor_tasks
255+
256+
hub_kwargs = {
257+
"revision": revision,
258+
"token": token,
259+
"trust_remote_code": trust_remote_code,
260+
"_commit_hash": commit_hash,
261+
}
262+
263+
if isinstance(model, Path):
264+
model = str(model)
265+
266+
if torch_dtype is not None:
267+
if "torch_dtype" in model_kwargs:
268+
raise ValueError(
269+
'You cannot use both `pipeline(... torch_dtype=..., model_kwargs={"torch_dtype":...})` as those'
270+
" arguments might conflict, use only one.)"
271+
)
272+
model_kwargs["torch_dtype"] = torch_dtype
273+
274+
# Load the correct model if possible
275+
# Infer the framework from the model if not already defined
276+
model, model_id = MAPPING_LOADING_FUNC[accelerator](model, task, supported_tasks, model_kwargs, hub_kwargs)
277+
278+
if load_tokenizer and tokenizer is None:
279+
tokenizer = AutoTokenizer.from_pretrained(model_id, **hub_kwargs, **model_kwargs)
280+
if load_feature_extractor and feature_extractor is None:
281+
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id, **hub_kwargs, **model_kwargs)
282+
283+
return transformers_pipeline(
284+
task,
285+
model=model,
286+
tokenizer=tokenizer,
287+
feature_extractor=feature_extractor,
288+
use_fast=use_fast,
289+
torch_dtype=torch_dtype,
290+
)

0 commit comments

Comments
 (0)