Commit bf2ae08

fix comments
1 parent 4effaa4 commit bf2ae08

3 files changed: +18, -30 lines


optimum/intel/ipex/modeling_base.py (+2)

@@ -161,6 +161,7 @@ def _from_transformers(
         local_files_only: bool = False,
         torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
         trust_remote_code: bool = False,
+        _commit_hash: str = None,
     ):
         if use_auth_token is not None:
             warnings.warn(
@@ -186,6 +187,7 @@ def _from_transformers(
             "force_download": force_download,
             "torch_dtype": torch_dtype,
             "trust_remote_code": trust_remote_code,
+            "_commit_hash": _commit_hash,
         }
 
         model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)
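
Note on the new parameter: when transformers downloads a config, it records the exact Hub commit it resolved on the returned object, and passing that hash along to later downloads keeps the config and the weights on the same repository snapshot even if new commits land in between. A minimal sketch of the idea (illustrative only; "gpt2" is a placeholder model id, and the commented-out call merely mirrors the line above, it is not this repo's code):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("gpt2")          # first download resolves and pins a commit
commit_hash = getattr(config, "_commit_hash", None)  # transformers records the pinned commit here

# Reusing the same hash for follow-up downloads avoids mixing snapshots, e.g.:
# model = TasksManager.get_model_from_task(task, "gpt2", _commit_hash=commit_hash)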

optimum/intel/pipelines/pipeline_base.py (+16, -29)

@@ -15,6 +15,7 @@
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Optional, Union
 
+import torch
 from transformers import AutoConfig, AutoFeatureExtractor, AutoTokenizer
 from transformers import pipeline as transformers_pipeline
 from transformers.feature_extraction_utils import PreTrainedFeatureExtractor
@@ -31,7 +32,6 @@
 from transformers.tokenization_utils import PreTrainedTokenizer
 from transformers.utils import (
     is_ipex_available,
-    is_torch_available,
     logging,
 )
 
@@ -98,13 +98,9 @@ def load_ipex_model(
     model,
     targeted_task,
     SUPPORTED_TASKS,
-    subfolder: str = "",
-    token: Optional[Union[bool, str]] = None,
-    revision: str = "main",
     model_kwargs: Optional[Dict[str, Any]] = None,
-    **kwargs,
+    hub_kwargs: Optional[Dict[str, Any]] = None,
 ):
-    export = kwargs.pop("export", True)
     if model_kwargs is None:
         model_kwargs = {}
 
@@ -118,15 +114,13 @@ def load_ipex_model(
         try:
             config = AutoConfig.from_pretrained(model)
             torchscript = getattr(config, "torchscript", None)
-            export = False if torchscript else export
+            export = False if torchscript else True
         except RuntimeError:
-            logger.warning(
-                "config file not found, please pass `export` to decide whether we should export this model. `export` defaullt to True"
-            )
-
+            logger.warning("We will use IPEXModel with export=True to export the model")
+            export = True
         model = ipex_model_class.from_pretrained(model, export=export, **model_kwargs, **hub_kwargs)
     elif isinstance(model, IPEXModel):
-        model_id = None
+        model_id = getattr(model.config, "name_or_path", None)
     else:
         raise ValueError(
             f"""Model {model} is not supported. Please provide a valid model name or path or a IPEXModel.
@@ -141,7 +135,6 @@ def load_ipex_model(
 }
 
 
-
 if TYPE_CHECKING:
     from transformers.modeling_utils import PreTrainedModel
     from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
@@ -160,8 +153,9 @@ def pipeline(
     accelerator: Optional[str] = "ort",
     revision: Optional[str] = None,
     trust_remote_code: Optional[bool] = None,
-    *model_kwargs,
-    **kwargs,
+    torch_dtype: Optional[Union[str, torch.dtype]] = None,
+    commit_hash: Optional[str] = None,
+    **model_kwargs,
 ) -> Pipeline:
     """
     Utility factory method to build a [`Pipeline`].
@@ -201,9 +195,6 @@ def pipeline(
         model_kwargs (`Dict[str, Any]`, *optional*):
             Additional dictionary of keyword arguments passed along to the model's `from_pretrained(...,
             **model_kwargs)` function.
-        kwargs (`Dict[str, Any]`, *optional*):
-            Additional keyword arguments passed along to the specific pipeline init (see the documentation for the
-            corresponding pipeline class for possible values).
 
     Returns:
         [`Pipeline`]: A suitable pipeline for the task.
@@ -235,7 +226,9 @@ def pipeline(
         )
 
     if accelerator not in MAPPING_LOADING_FUNC:
-        raise ValueError(f'Accelerator {accelerator} is not supported. Supported accelerator is {", ".join(MAPPING_LOADING_FUNC)}.')
+        raise ValueError(
+            f'Accelerator {accelerator} is not supported. Supported accelerator is {", ".join(MAPPING_LOADING_FUNC)}.'
+        )
 
     if accelerator == "ipex":
         if task not in list(IPEX_SUPPORTED_TASKS.keys()):
@@ -260,12 +253,10 @@
     load_tokenizer = task not in no_tokenizer_tasks
     load_feature_extractor = task not in no_feature_extractor_tasks
 
-    commit_hash = kwargs.pop("_commit_hash", None)
-
     hub_kwargs = {
-        "revision": kwargs.pop("revision", None),
-        "token": kwargs.pop("use_auth_token", None),
-        "trust_remote_code": kwargs.pop("trust_remote_code", None),
+        "revision": revision,
+        "token": token,
+        "trust_remote_code": trust_remote_code,
         "_commit_hash": commit_hash,
     }
 
@@ -282,22 +273,18 @@
 
     # Load the correct model if possible
     # Infer the framework from the model if not already defined
-    model, model_id = MAPPING_LOADING_FUNC[accelerator](
-        model, task, supported_tasks, model_kwargs, hub_kwargs, **kwargs
-    )
+    model, model_id = MAPPING_LOADING_FUNC[accelerator](model, task, supported_tasks, model_kwargs, hub_kwargs)
 
     if load_tokenizer and tokenizer is None:
         tokenizer = AutoTokenizer.from_pretrained(model_id, **hub_kwargs, **model_kwargs)
     if load_feature_extractor and feature_extractor is None:
         feature_extractor = AutoFeatureExtractor.from_pretrained(model_id, **hub_kwargs, **model_kwargs)
 
-
     return transformers_pipeline(
         task,
         model=model,
         tokenizer=tokenizer,
         feature_extractor=feature_extractor,
         use_fast=use_fast,
         torch_dtype=torch_dtype,
-        **kwargs,
     )
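
Taken together, `pipeline` now takes `torch_dtype` and `commit_hash` as explicit keyword arguments, and any remaining keywords are collected into `**model_kwargs` and forwarded to `from_pretrained`, rather than being fished out of a catch-all `**kwargs`. A hedged usage sketch (assuming `pipeline` is exported from `optimum.intel.pipelines`; "gpt2" is a placeholder model id, not a value from this commit):

import torch

from optimum.intel.pipelines import pipeline

pipe = pipeline(
    "text-generation",
    model="gpt2",                # placeholder, any supported checkpoint
    accelerator="ipex",
    revision="main",             # explicit parameter, no longer kwargs.pop("revision")
    torch_dtype=torch.bfloat16,  # new explicit keyword argument
)
print(pipe("Hello, world", max_new_tokens=8))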

tests/ipex/test_pipelines.py (-1)

@@ -22,7 +22,6 @@
 from transformers.pipelines import pipeline as transformers_pipeline
 
 from optimum.intel.ipex.modeling_base import (
-    IPEXModel,
     IPEXModelForAudioClassification,
     IPEXModelForCausalLM,
     IPEXModelForImageClassification,
