Skip to content

Commit 568aa35

Browse files
authored
Fix use_auth_token with ORTModel (#1740)
fix use_auth_token
1 parent 7e08a82 commit 568aa35

File tree

3 files changed: +19 / −5 lines changed

optimum/exporters/tasks.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -1378,6 +1378,7 @@ def get_model_files(
13781378
model_name_or_path: Union[str, Path],
13791379
subfolder: str = "",
13801380
cache_dir: str = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE,
1381+
use_auth_token: Optional[str] = None,
13811382
):
13821383
request_exception = None
13831384
full_model_path = Path(model_name_or_path) / subfolder
@@ -1391,7 +1392,9 @@ def get_model_files(
13911392
try:
13921393
if not isinstance(model_name_or_path, str):
13931394
model_name_or_path = str(model_name_or_path)
1394-
all_files = huggingface_hub.list_repo_files(model_name_or_path, repo_type="model")
1395+
all_files = huggingface_hub.list_repo_files(
1396+
model_name_or_path, repo_type="model", token=use_auth_token
1397+
)
13951398
if subfolder != "":
13961399
all_files = [file[len(subfolder) + 1 :] for file in all_files if file.startswith(subfolder)]
13971400
except RequestsConnectionError as e: # Hub not accessible
@@ -1672,6 +1675,7 @@ def infer_library_from_model(
16721675
revision: Optional[str] = None,
16731676
cache_dir: str = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE,
16741677
library_name: Optional[str] = None,
1678+
use_auth_token: Optional[str] = None,
16751679
):
16761680
"""
16771681
Infers the library from the model repo.
@@ -1689,13 +1693,17 @@ def infer_library_from_model(
16891693
Path to a directory in which a downloaded pretrained model weights have been cached if the standard cache should not be used.
16901694
library_name (`Optional[str]`, *optional*):
16911695
The library name of the model. Can be any of "transformers", "timm", "diffusers", "sentence_transformers".
1696+
use_auth_token (`Optional[str]`, defaults to `None`):
1697+
The token to use as HTTP bearer authorization for remote files.
16921698
Returns:
16931699
`str`: The library name automatically detected from the model repo.
16941700
"""
16951701
if library_name is not None:
16961702
return library_name
16971703

1698-
all_files, _ = TasksManager.get_model_files(model_name_or_path, subfolder, cache_dir)
1704+
all_files, _ = TasksManager.get_model_files(
1705+
model_name_or_path, subfolder, cache_dir, use_auth_token=use_auth_token
1706+
)
16991707

17001708
if "model_index.json" in all_files:
17011709
library_name = "diffusers"
@@ -1710,6 +1718,7 @@ def infer_library_from_model(
17101718
"subfolder": subfolder,
17111719
"revision": revision,
17121720
"cache_dir": cache_dir,
1721+
"use_auth_token": use_auth_token,
17131722
}
17141723
config_dict, kwargs = PretrainedConfig.get_config_dict(model_name_or_path, **kwargs)
17151724
model_config = PretrainedConfig.from_dict(config_dict, **kwargs)

optimum/modeling_base.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,9 @@ def from_pretrained(
346346
)
347347
model_id, revision = model_id.split("@")
348348

349-
library_name = TasksManager.infer_library_from_model(model_id, subfolder, revision, cache_dir)
349+
library_name = TasksManager.infer_library_from_model(
350+
model_id, subfolder, revision, cache_dir, use_auth_token=use_auth_token
351+
)
350352

351353
if library_name == "timm":
352354
config = PretrainedConfig.from_pretrained(model_id, subfolder, revision)

tests/onnxruntime/test_modeling.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -937,9 +937,12 @@ def test_stable_diffusion_model_on_rocm_ep_str(self):
937937
self.assertEqual(model.vae_encoder.session.get_providers()[0], "ROCMExecutionProvider")
938938
self.assertListEqual(model.providers, ["ROCMExecutionProvider", "CPUExecutionProvider"])
939939

940-
@require_hf_token
941940
def test_load_model_from_hub_private(self):
942-
model = ORTModel.from_pretrained(self.ONNX_MODEL_ID, use_auth_token=os.environ.get("HF_AUTH_TOKEN", None))
941+
subprocess.run("huggingface-cli logout", shell=True)
942+
# Read token of fxmartyclone (dummy user).
943+
token = "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"
944+
945+
model = ORTModelForCustomTasks.from_pretrained("fxmartyclone/tiny-onnx-private-2", use_auth_token=token)
943946
self.assertIsInstance(model.model, onnxruntime.InferenceSession)
944947
self.assertIsInstance(model.config, PretrainedConfig)
945948

Commit comments (0)