Skip to content

Commit 8a6a2b9

Browse files
committed
fix model files loading when detected
1 parent 58120be commit 8a6a2b9

File tree

1 file changed

+40
-38
lines changed

1 file changed

+40
-38
lines changed

optimum/onnxruntime/modeling_ort.py

+40-38
Original file line numberDiff line numberDiff line change
@@ -485,50 +485,52 @@ def _from_pretrained(
485485
**kwargs,
486486
) -> "ORTModel":
487487
model_path = Path(model_id)
488+
defaut_file_name = file_name or "model.onnx"
489+
490+
if local_files_only:
491+
object_id = str(model_id).replace("/", "--")
492+
cached_model_dir = os.path.join(cache_dir, f"models--{object_id}")
493+
refs_file = os.path.join(os.path.join(cached_model_dir, "refs"), revision or "main")
494+
with open(refs_file) as f:
495+
_revision = f.read()
496+
model_dir = os.path.join(cached_model_dir, "snapshots", _revision)
497+
else:
498+
model_dir = str(model_id)
488499

489-
if file_name is None:
490-
if local_files_only:
491-
object_id = str(model_id).replace("/", "--")
492-
cached_model_dir = os.path.join(cache_dir, f"models--{object_id}")
493-
refs_file = os.path.join(os.path.join(cached_model_dir, "refs"), revision or "main")
494-
with open(refs_file) as f:
495-
_revision = f.read()
496-
model_dir = os.path.join(cached_model_dir, "snapshots", _revision)
497-
else:
498-
model_dir = str(model_id)
499-
500-
onnx_files = find_files_matching_pattern(
501-
model_dir,
502-
ONNX_FILE_PATTERN,
503-
glob_pattern="**/*.onnx",
504-
subfolder=subfolder,
505-
token=token,
506-
revision=revision,
507-
)
500+
onnx_files = find_files_matching_pattern(
501+
model_dir,
502+
ONNX_FILE_PATTERN,
503+
glob_pattern="**/*.onnx",
504+
subfolder=subfolder,
505+
token=token,
506+
revision=revision,
507+
)
508508

509-
model_path = Path(model_dir)
510-
if len(onnx_files) == 0:
511-
raise FileNotFoundError(f"Could not find any ONNX model file in {model_dir}")
509+
model_path = Path(model_dir)
510+
if len(onnx_files) == 0:
511+
raise FileNotFoundError(f"Could not find any ONNX model file in {model_dir}")
512+
if len(onnx_files) == 1 and file_name and file_name != onnx_files[0].name:
513+
raise FileNotFoundError(f"Trying to load {file_name} but only found {onnx_files[0].name}")
512514

513-
file_name = onnx_files[0].name
514-
subfolder = onnx_files[0].parent
515+
file_name = onnx_files[0].name
516+
subfolder = onnx_files[0].parent
515517

516-
if len(onnx_files) > 1:
517-
for file in onnx_files:
518-
if file.name == "model.onnx":
519-
file_name = file.name
520-
subfolder = file.parent
521-
break
518+
if len(onnx_files) > 1:
519+
for file in onnx_files:
520+
if file.name == defaut_file_name:
521+
file_name = file.name
522+
subfolder = file.parent
523+
break
522524

523-
logger.warning(
524-
f"Too many ONNX model files were found in {' ,'.join(map(str, onnx_files))}. "
525-
"specify which one to load by using the `file_name` and/or the `subfolder` arguments. "
526-
f"Loading the file {file_name} in the subfolder {subfolder}."
527-
)
525+
logger.warning(
526+
f"Too many ONNX model files were found in {' ,'.join(map(str, onnx_files))}. "
527+
"specify which one to load by using the `file_name` and/or the `subfolder` arguments. "
528+
f"Loading the file {file_name} in the subfolder {subfolder}."
529+
)
528530

529-
if model_path.is_dir():
530-
model_path = subfolder
531-
subfolder = ""
531+
if model_path.is_dir():
532+
model_path = subfolder
533+
subfolder = ""
532534

533535
model_cache_path, preprocessors = cls._cached_file(
534536
model_path=model_path,

0 commit comments

Comments
 (0)