@@ -485,50 +485,52 @@ def _from_pretrained(
485
485
** kwargs ,
486
486
) -> "ORTModel" :
487
487
model_path = Path (model_id )
488
+ defaut_file_name = file_name or "model.onnx"
489
+
490
+ if local_files_only :
491
+ object_id = str (model_id ).replace ("/" , "--" )
492
+ cached_model_dir = os .path .join (cache_dir , f"models--{ object_id } " )
493
+ refs_file = os .path .join (os .path .join (cached_model_dir , "refs" ), revision or "main" )
494
+ with open (refs_file ) as f :
495
+ _revision = f .read ()
496
+ model_dir = os .path .join (cached_model_dir , "snapshots" , _revision )
497
+ else :
498
+ model_dir = str (model_id )
488
499
489
- if file_name is None :
490
- if local_files_only :
491
- object_id = str (model_id ).replace ("/" , "--" )
492
- cached_model_dir = os .path .join (cache_dir , f"models--{ object_id } " )
493
- refs_file = os .path .join (os .path .join (cached_model_dir , "refs" ), revision or "main" )
494
- with open (refs_file ) as f :
495
- _revision = f .read ()
496
- model_dir = os .path .join (cached_model_dir , "snapshots" , _revision )
497
- else :
498
- model_dir = str (model_id )
499
-
500
- onnx_files = find_files_matching_pattern (
501
- model_dir ,
502
- ONNX_FILE_PATTERN ,
503
- glob_pattern = "**/*.onnx" ,
504
- subfolder = subfolder ,
505
- token = token ,
506
- revision = revision ,
507
- )
500
+ onnx_files = find_files_matching_pattern (
501
+ model_dir ,
502
+ ONNX_FILE_PATTERN ,
503
+ glob_pattern = "**/*.onnx" ,
504
+ subfolder = subfolder ,
505
+ token = token ,
506
+ revision = revision ,
507
+ )
508
508
509
- model_path = Path (model_dir )
510
- if len (onnx_files ) == 0 :
511
- raise FileNotFoundError (f"Could not find any ONNX model file in { model_dir } " )
509
+ model_path = Path (model_dir )
510
+ if len (onnx_files ) == 0 :
511
+ raise FileNotFoundError (f"Could not find any ONNX model file in { model_dir } " )
512
+ if len (onnx_files ) == 1 and file_name and file_name != onnx_files [0 ].name :
513
+ raise FileNotFoundError (f"Trying to load { file_name } but only found { onnx_files [0 ].name } " )
512
514
513
- file_name = onnx_files [0 ].name
514
- subfolder = onnx_files [0 ].parent
515
+ file_name = onnx_files [0 ].name
516
+ subfolder = onnx_files [0 ].parent
515
517
516
- if len (onnx_files ) > 1 :
517
- for file in onnx_files :
518
- if file .name == "model.onnx" :
519
- file_name = file .name
520
- subfolder = file .parent
521
- break
518
+ if len (onnx_files ) > 1 :
519
+ for file in onnx_files :
520
+ if file .name == defaut_file_name :
521
+ file_name = file .name
522
+ subfolder = file .parent
523
+ break
522
524
523
- logger .warning (
524
- f"Too many ONNX model files were found in { ' ,' .join (map (str , onnx_files ))} . "
525
- "specify which one to load by using the `file_name` and/or the `subfolder` arguments. "
526
- f"Loading the file { file_name } in the subfolder { subfolder } ."
527
- )
525
+ logger .warning (
526
+ f"Too many ONNX model files were found in { ' ,' .join (map (str , onnx_files ))} . "
527
+ "specify which one to load by using the `file_name` and/or the `subfolder` arguments. "
528
+ f"Loading the file { file_name } in the subfolder { subfolder } ."
529
+ )
528
530
529
- if model_path .is_dir ():
530
- model_path = subfolder
531
- subfolder = ""
531
+ if model_path .is_dir ():
532
+ model_path = subfolder
533
+ subfolder = ""
532
534
533
535
model_cache_path , preprocessors = cls ._cached_file (
534
536
model_path = model_path ,
0 commit comments