Skip to content

Commit 5013fe7

Browse files
committed
fix _from_transformers args
1 parent 6056612 commit 5013fe7

File tree

1 file changed

+2
-20
lines changed

1 file changed

+2
-20
lines changed

optimum/intel/ipex/modeling_base.py

+2-20
Original file line number · Diff line number · Diff line change
@@ -151,35 +151,17 @@ def _from_transformers(
151151
model_id: str,
152152
config: PretrainedConfig,
153153
use_cache: bool = True,
154-
use_auth_token: Optional[Union[bool, str]] = None,
155-
revision: Optional[str] = None,
156-
force_download: bool = False,
157-
cache_dir: str = HUGGINGFACE_HUB_CACHE,
158-
subfolder: str = "",
159-
local_files_only: bool = False,
160-
torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
161-
trust_remote_code: bool = False,
154+
**model_kwargs,
162155
):
163156
if is_torch_version("<", "2.1.0"):
164157
raise ImportError("`torch>=2.0.0` is needed to trace your model")
165158

166159
task = cls.export_feature
167-
model_kwargs = {
168-
"revision": revision,
169-
"use_auth_token": use_auth_token,
170-
"cache_dir": cache_dir,
171-
"subfolder": subfolder,
172-
"local_files_only": local_files_only,
173-
"force_download": force_download,
174-
"torch_dtype": torch_dtype,
175-
"trust_remote_code": trust_remote_code,
176-
}
177-
178160
model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)
179161
traced_model = ipex_jit_trace(model, task, use_cache)
180162

181163
config.torchscript = True
182-
config.torch_dtype = torch_dtype
164+
config.torch_dtype = model_kwargs.get("torch_dtype", None)
183165

184166
return cls(traced_model, config=config, model_save_dir=model_id, use_cache=use_cache, warmup=False)
185167

0 commit comments

Comments (0)