15
15
from pathlib import Path
16
16
from typing import TYPE_CHECKING , Any , Dict , Optional , Union
17
17
18
+ import torch
18
19
from transformers import AutoConfig , AutoFeatureExtractor , AutoTokenizer
19
20
from transformers import pipeline as transformers_pipeline
20
21
from transformers .feature_extraction_utils import PreTrainedFeatureExtractor
31
32
from transformers .tokenization_utils import PreTrainedTokenizer
32
33
from transformers .utils import (
33
34
is_ipex_available ,
34
- is_torch_available ,
35
35
logging ,
36
36
)
37
37
@@ -98,13 +98,9 @@ def load_ipex_model(
98
98
model ,
99
99
targeted_task ,
100
100
SUPPORTED_TASKS ,
101
- subfolder : str = "" ,
102
- token : Optional [Union [bool , str ]] = None ,
103
- revision : str = "main" ,
104
101
model_kwargs : Optional [Dict [str , Any ]] = None ,
105
- ** kwargs ,
102
+ hub_kwargs : Optional [ Dict [ str , Any ]] = None ,
106
103
):
107
- export = kwargs .pop ("export" , True )
108
104
if model_kwargs is None :
109
105
model_kwargs = {}
110
106
@@ -118,15 +114,13 @@ def load_ipex_model(
118
114
try :
119
115
config = AutoConfig .from_pretrained (model )
120
116
torchscript = getattr (config , "torchscript" , None )
121
- export = False if torchscript else export
117
+ export = False if torchscript else True
122
118
except RuntimeError :
123
- logger .warning (
124
- "config file not found, please pass `export` to decide whether we should export this model. `export` defaullt to True"
125
- )
126
-
119
+ logger .warning ("We will use IPEXModel with export=True to export the model" )
120
+ export = True
127
121
model = ipex_model_class .from_pretrained (model , export = export , ** model_kwargs , ** hub_kwargs )
128
122
elif isinstance (model , IPEXModel ):
129
- model_id = None
123
+ model_id = getattr ( model . config , "name_or_path" , None )
130
124
else :
131
125
raise ValueError (
132
126
f"""Model { model } is not supported. Please provide a valid model name or path or a IPEXModel.
@@ -141,7 +135,6 @@ def load_ipex_model(
141
135
}
142
136
143
137
144
-
145
138
if TYPE_CHECKING :
146
139
from transformers .modeling_utils import PreTrainedModel
147
140
from transformers .tokenization_utils_fast import PreTrainedTokenizerFast
@@ -160,8 +153,9 @@ def pipeline(
160
153
accelerator : Optional [str ] = "ort" ,
161
154
revision : Optional [str ] = None ,
162
155
trust_remote_code : Optional [bool ] = None ,
163
- * model_kwargs ,
164
- ** kwargs ,
156
+ torch_dtype : Optional [Union [str , torch .dtype ]] = None ,
157
+ commit_hash : Optional [str ] = None ,
158
+ ** model_kwargs ,
165
159
) -> Pipeline :
166
160
"""
167
161
Utility factory method to build a [`Pipeline`].
@@ -201,9 +195,6 @@ def pipeline(
201
195
model_kwargs (`Dict[str, Any]`, *optional*):
202
196
Additional dictionary of keyword arguments passed along to the model's `from_pretrained(...,
203
197
**model_kwargs)` function.
204
- kwargs (`Dict[str, Any]`, *optional*):
205
- Additional keyword arguments passed along to the specific pipeline init (see the documentation for the
206
- corresponding pipeline class for possible values).
207
198
208
199
Returns:
209
200
[`Pipeline`]: A suitable pipeline for the task.
@@ -235,7 +226,9 @@ def pipeline(
235
226
)
236
227
237
228
if accelerator not in MAPPING_LOADING_FUNC :
238
- raise ValueError (f'Accelerator { accelerator } is not supported. Supported accelerator is { ", " .join (MAPPING_LOADING_FUNC )} .' )
229
+ raise ValueError (
230
+ f'Accelerator { accelerator } is not supported. Supported accelerator is { ", " .join (MAPPING_LOADING_FUNC )} .'
231
+ )
239
232
240
233
if accelerator == "ipex" :
241
234
if task not in list (IPEX_SUPPORTED_TASKS .keys ()):
@@ -260,12 +253,10 @@ def pipeline(
260
253
load_tokenizer = task not in no_tokenizer_tasks
261
254
load_feature_extractor = task not in no_feature_extractor_tasks
262
255
263
- commit_hash = kwargs .pop ("_commit_hash" , None )
264
-
265
256
hub_kwargs = {
266
- "revision" : kwargs . pop ( " revision" , None ) ,
267
- "token" : kwargs . pop ( "use_auth_token" , None ) ,
268
- "trust_remote_code" : kwargs . pop ( " trust_remote_code" , None ) ,
257
+ "revision" : revision ,
258
+ "token" : token ,
259
+ "trust_remote_code" : trust_remote_code ,
269
260
"_commit_hash" : commit_hash ,
270
261
}
271
262
@@ -282,22 +273,18 @@ def pipeline(
282
273
283
274
# Load the correct model if possible
284
275
# Infer the framework from the model if not already defined
285
- model , model_id = MAPPING_LOADING_FUNC [accelerator ](
286
- model , task , supported_tasks , model_kwargs , hub_kwargs , ** kwargs
287
- )
276
+ model , model_id = MAPPING_LOADING_FUNC [accelerator ](model , task , supported_tasks , model_kwargs , hub_kwargs )
288
277
289
278
if load_tokenizer and tokenizer is None :
290
279
tokenizer = AutoTokenizer .from_pretrained (model_id , ** hub_kwargs , ** model_kwargs )
291
280
if load_feature_extractor and feature_extractor is None :
292
281
feature_extractor = AutoFeatureExtractor .from_pretrained (model_id , ** hub_kwargs , ** model_kwargs )
293
282
294
-
295
283
return transformers_pipeline (
296
284
task ,
297
285
model = model ,
298
286
tokenizer = tokenizer ,
299
287
feature_extractor = feature_extractor ,
300
288
use_fast = use_fast ,
301
289
torch_dtype = torch_dtype ,
302
- ** kwargs ,
303
290
)
0 commit comments