@@ -90,7 +90,7 @@ def ipex_jit_trace(model, task, use_cache):
90
90
91
91
if _is_patched_with_ipex (model , task ):
92
92
model = _patch_model (model )
93
- # Todo : integerate in prepare_jit_inputs.
93
+ # TODO : integerate in prepare_jit_inputs.
94
94
sample_inputs = get_dummy_input (model , return_dict = True )
95
95
# Use Tensor Processing Primitives to accelerate linear, see https://arxiv.org/abs/2104.05755.
96
96
_enable_tpp ()
@@ -151,7 +151,7 @@ def __init__(
151
151
logger .warning ("The model has been exported already." )
152
152
else :
153
153
config = model .config if config is None else config
154
- use_cache = kwargs .get ("use_cache" , None )
154
+ use_cache = kwargs .get ("use_cache" , True )
155
155
model = ipex_jit_trace (model , self .export_feature , use_cache )
156
156
config .torchscript = True
157
157
@@ -162,11 +162,13 @@ def __init__(
162
162
self .model_save_dir = model_save_dir
163
163
self ._is_ipex_exported = _is_patched_with_ipex (model , self .export_feature )
164
164
165
- self .input_names = (
166
- {inputs .debugName ().split ("." )[0 ] for inputs in model .graph .inputs () if inputs .debugName () != "self" }
167
- if isinstance (model , torch .jit .RecursiveScriptModule )
168
- else inspect .signature (model .forward ).parameters
169
- )
165
+ if isinstance (model , torch .jit .RecursiveScriptModule ):
166
+ self .input_names = {
167
+ inputs .debugName ().split ("." )[0 ] for inputs in model .graph .inputs () if inputs .debugName () != "self"
168
+ }
169
+ else :
170
+ self .input_names = set (inspect .signature (model .forward ).parameters )
171
+
170
172
# Registers the IPEXModelForXXX classes into the transformers AutoModel classes to avoid warnings when creating
171
173
# a pipeline https://github.com/huggingface/transformers/blob/cad61b68396a1a387287a8e2e2fef78a25b79383/src/transformers/pipelines/base.py#L863
172
174
AutoConfig .register (self .base_model_prefix , AutoConfig )
@@ -184,7 +186,6 @@ def _from_pretrained(
184
186
cls ,
185
187
model_id : Union [str , Path ],
186
188
config : PretrainedConfig ,
187
- use_cache : bool = True ,
188
189
use_auth_token : Optional [Union [bool , str ]] = None ,
189
190
token : Optional [Union [bool , str ]] = None ,
190
191
revision : Optional [str ] = None ,
@@ -194,7 +195,6 @@ def _from_pretrained(
194
195
local_files_only : bool = False ,
195
196
torch_dtype : Optional [Union [str , "torch.dtype" ]] = None ,
196
197
trust_remote_code : bool = False ,
197
- _commit_hash : str = None ,
198
198
file_name : Optional [str ] = WEIGHTS_NAME ,
199
199
** kwargs ,
200
200
):
@@ -209,6 +209,17 @@ def _from_pretrained(
209
209
)
210
210
token = use_auth_token
211
211
212
+ commit_hash = kwargs .pop ("_commit_hash" , None )
213
+
214
+ model_kwargs = {
215
+ "revision" : revision ,
216
+ "token" : token ,
217
+ "cache_dir" : cache_dir ,
218
+ "subfolder" : subfolder ,
219
+ "local_files_only" : local_files_only ,
220
+ "force_download" : force_download ,
221
+ }
222
+
212
223
if not getattr (config , "torchscript" , False ):
213
224
logger .warning ("Detect torchscript is false. Convert to torchscript model!" )
214
225
@@ -217,44 +228,30 @@ def _from_pretrained(
217
228
218
229
task = cls .export_feature
219
230
config .torch_dtype = torch_dtype
220
- model_kwargs = {
221
- "revision" : revision ,
222
- "token" : token ,
223
- "cache_dir" : cache_dir ,
224
- "subfolder" : subfolder ,
225
- "local_files_only" : local_files_only ,
226
- "force_download" : force_download ,
227
- "torch_dtype" : torch_dtype ,
228
- "trust_remote_code" : trust_remote_code ,
229
- "_commit_hash" : _commit_hash ,
230
- }
231
-
232
- model = TasksManager .get_model_from_task (task , model_id , ** model_kwargs )
231
+ model = TasksManager .get_model_from_task (
232
+ task ,
233
+ model_id ,
234
+ trust_remote_code = trust_remote_code ,
235
+ torch_dtype = torch_dtype ,
236
+ _commit_hash = commit_hash ,
237
+ ** model_kwargs ,
238
+ )
233
239
234
- return cls (model , config = config , export = True , use_cache = use_cache , ** kwargs )
240
+ return cls (model , config = config , export = True , ** kwargs )
235
241
236
242
# Load the model from local directory
237
243
if os .path .isdir (model_id ):
238
244
model_cache_path = os .path .join (model_id , file_name )
239
245
model_save_dir = model_id
240
246
# Download the model from the hub
241
247
else :
242
- model_cache_path = hf_hub_download (
243
- repo_id = model_id ,
244
- filename = file_name ,
245
- token = token ,
246
- revision = revision ,
247
- cache_dir = cache_dir ,
248
- force_download = force_download ,
249
- local_files_only = local_files_only ,
250
- subfolder = subfolder ,
251
- )
248
+ model_cache_path = hf_hub_download (repo_id = model_id , filename = file_name , ** model_kwargs )
252
249
model_save_dir = Path (model_cache_path ).parent
253
250
254
251
model = torch .jit .load (model_cache_path )
255
252
torch .jit .freeze (model .eval ())
256
253
257
- return cls (model , config = config , model_save_dir = model_save_dir , use_cache = use_cache , ** kwargs )
254
+ return cls (model , config = config , model_save_dir = model_save_dir , ** kwargs )
258
255
259
256
def _save_pretrained (self , save_directory : Union [str , Path ]):
260
257
output_path = os .path .join (save_directory , WEIGHTS_NAME )
0 commit comments