@@ -151,35 +151,17 @@ def _from_transformers(
151
151
model_id : str ,
152
152
config : PretrainedConfig ,
153
153
use_cache : bool = True ,
154
- use_auth_token : Optional [Union [bool , str ]] = None ,
155
- revision : Optional [str ] = None ,
156
- force_download : bool = False ,
157
- cache_dir : str = HUGGINGFACE_HUB_CACHE ,
158
- subfolder : str = "" ,
159
- local_files_only : bool = False ,
160
- torch_dtype : Optional [Union [str , "torch.dtype" ]] = None ,
161
- trust_remote_code : bool = False ,
154
+ ** model_kwargs ,
162
155
):
163
156
if is_torch_version ("<" , "2.1.0" ):
164
157
raise ImportError ("`torch>=2.0.0` is needed to trace your model" )
165
158
166
159
task = cls .export_feature
167
- model_kwargs = {
168
- "revision" : revision ,
169
- "use_auth_token" : use_auth_token ,
170
- "cache_dir" : cache_dir ,
171
- "subfolder" : subfolder ,
172
- "local_files_only" : local_files_only ,
173
- "force_download" : force_download ,
174
- "torch_dtype" : torch_dtype ,
175
- "trust_remote_code" : trust_remote_code ,
176
- }
177
-
178
160
model = TasksManager .get_model_from_task (task , model_id , ** model_kwargs )
179
161
traced_model = ipex_jit_trace (model , task , use_cache )
180
162
181
163
config .torchscript = True
182
- config .torch_dtype = torch_dtype
164
+ config .torch_dtype = model_kwargs . get ( " torch_dtype" , None )
183
165
184
166
return cls (traced_model , config = config , model_save_dir = model_id , use_cache = use_cache , warmup = False )
185
167
0 commit comments