14
14
"""Classes handling causal-lm related architectures in ONNX Runtime."""
15
15
16
16
import logging
17
+ import os
18
+ import re
17
19
from pathlib import Path
18
20
from tempfile import TemporaryDirectory
19
21
from typing import TYPE_CHECKING , Any , Dict , List , Optional , Tuple , Union
32
34
from ..exporters .onnx import MODEL_TYPES_REQUIRING_POSITION_IDS , main_export
33
35
from ..onnx .utils import check_model_uses_external_data
34
36
from ..utils import NormalizedConfigManager , is_transformers_version
35
- from ..utils .modeling_utils import MODEL_TO_PATCH_FOR_PAST
37
+ from ..utils .file_utils import find_files_matching_pattern
36
38
from ..utils .save_utils import maybe_save_preprocessors
37
- from .constants import DECODER_MERGED_ONNX_FILE_PATTERN , DECODER_ONNX_FILE_PATTERN , DECODER_WITH_PAST_ONNX_FILE_PATTERN
39
+ from .constants import (
40
+ DECODER_MERGED_ONNX_FILE_PATTERN ,
41
+ DECODER_ONNX_FILE_PATTERN ,
42
+ DECODER_WITH_PAST_ONNX_FILE_PATTERN ,
43
+ ONNX_FILE_PATTERN ,
44
+ )
38
45
from .modeling_ort import ONNX_MODEL_END_DOCSTRING , ORTModel
39
46
from .models .bloom import bloom_convert_to_bloom_cache , bloom_convert_to_standard_cache
40
- from .utils import ONNX_DECODER_NAME , ONNX_DECODER_WITH_PAST_NAME , ONNX_WEIGHTS_NAME
47
+ from .utils import ONNX_WEIGHTS_NAME
41
48
42
49
43
50
if TYPE_CHECKING :
@@ -400,7 +407,6 @@ def _from_pretrained(
400
407
** kwargs ,
401
408
) -> "ORTModelForCausalLM" :
402
409
generation_config = kwargs .pop ("generation_config" , None )
403
- model_path = Path (model_id )
404
410
405
411
# We do not implement the logic for use_cache=False, use_merged=True
406
412
if use_cache is False :
@@ -411,68 +417,69 @@ def _from_pretrained(
411
417
)
412
418
use_merged = False
413
419
414
- decoder_name = "decoder_file_name" if use_cache else "decoder_with_past_file_name"
415
- decoder_file_name = kwargs .pop (decoder_name , None )
420
+ onnx_files = find_files_matching_pattern (
421
+ model_id ,
422
+ ONNX_FILE_PATTERN ,
423
+ glob_pattern = "**/*.onnx" ,
424
+ subfolder = subfolder ,
425
+ token = token ,
426
+ revision = revision ,
427
+ )
428
+
429
+ if len (onnx_files ) == 0 :
430
+ raise FileNotFoundError (f"Could not find any ONNX model file in { model_id } " )
416
431
417
- if decoder_file_name is not None :
418
- logger .warning (f"The `{ decoder_name } ` argument is deprecated, please use `file_name` instead." )
419
- file_name = file_name or decoder_file_name
432
+ if len (onnx_files ) == 1 :
433
+ subfolder = onnx_files [0 ].parent
434
+ _file_name = onnx_files [0 ].name
435
+ if file_name and file_name != _file_name :
436
+ raise FileNotFoundError (f"Trying to load { file_name } but only found { _file_name } " )
437
+ file_name = _file_name
420
438
421
- if file_name is None :
422
- decoder_path = None
423
- # We use `is not False` here to include two cases: use_merged = None (in which case we auto-detect it),
424
- # and use_merged = True (explicitely specified by the user)
439
+ else :
440
+ model_files = []
441
+ # Check first for merged models and then for decoder / decoder_with_past models
425
442
if use_merged is not False :
426
- try :
427
- decoder_path = ORTModelForCausalLM .infer_onnx_filename (
428
- model_id ,
429
- [DECODER_MERGED_ONNX_FILE_PATTERN ],
430
- argument_name = None ,
431
- subfolder = subfolder ,
432
- token = token ,
433
- revision = revision ,
434
- )
435
- use_merged = True
436
- file_name = decoder_path .name
437
- except FileNotFoundError as e :
438
- if use_merged is True :
439
- raise FileNotFoundError (
440
- "The parameter `use_merged=True` was passed to ORTModelForCausalLM.from_pretrained()"
441
- " but no ONNX file for a merged decoder could be found in"
442
- f" { str (Path (model_id , subfolder ))} , with the error: { e } "
443
- )
444
- use_merged = False
443
+ model_files = [p for p in onnx_files if re .search (DECODER_MERGED_ONNX_FILE_PATTERN , str (p ))]
444
+ use_merged = len (model_files ) != 0
445
445
446
446
if use_merged is False :
447
447
pattern = DECODER_WITH_PAST_ONNX_FILE_PATTERN if use_cache else DECODER_ONNX_FILE_PATTERN
448
- # exclude decoder file for first iteration
449
- decoder_path = ORTModelForCausalLM .infer_onnx_filename (
450
- model_id ,
451
- [r"^((?!decoder).)*.onnx" , pattern ],
452
- argument_name = None ,
453
- subfolder = subfolder ,
454
- token = token ,
455
- revision = revision ,
456
- )
457
- file_name = decoder_path .name
448
+ model_files = [p for p in onnx_files if re .search (pattern , str (p ))]
458
449
459
- if file_name == ONNX_DECODER_WITH_PAST_NAME and config .model_type in MODEL_TO_PATCH_FOR_PAST :
460
- raise ValueError (
461
- f"ONNX Runtime inference using { ONNX_DECODER_WITH_PAST_NAME } has been deprecated for { config .model_type } architecture. Please re-export your model with optimum>=1.14.0 or set use_cache=False. For details about the deprecation, please refer to https://github.com/huggingface/optimum/releases/tag/v1.14.0."
450
+ # if file_name is specified we don't filter legacy models
451
+ if not model_files or file_name :
452
+ model_files = onnx_files
453
+ else :
454
+ logger .warning (
455
+ f"Legacy models found in { model_files } will be loaded. "
456
+ "Legacy models will be deprecated in the next version of optimum, please re-export your model"
462
457
)
458
+ _file_name = model_files [0 ].name
459
+ subfolder = model_files [0 ].parent
463
460
464
- regular_file_names = []
465
- for name in [ONNX_WEIGHTS_NAME , ONNX_DECODER_WITH_PAST_NAME if use_cache else ONNX_DECODER_NAME ]:
466
- regular_file_names += ORTModelForCausalLM ._generate_regular_names_for_filename (name )
461
+ defaut_file_name = file_name or "model.onnx"
462
+ for file in model_files :
463
+ if file .name == defaut_file_name :
464
+ _file_name = file .name
465
+ subfolder = file .parent
466
+ break
467
467
468
- if file_name not in regular_file_names :
468
+ file_name = _file_name
469
+
470
+ if len (model_files ) > 1 :
469
471
logger .warning (
470
- f"The ONNX file { file_name } is not a regular name used in optimum.onnxruntime that are { regular_file_names } , the "
471
- f"{ cls .__name__ } might not behave as expected."
472
+ f"Too many ONNX model files were found in { ' ,' .join (map (str , model_files ))} . "
473
+ "specify which one to load by using the `file_name` and/or the `subfolder` arguments. "
474
+ f"Loading the file { file_name } in the subfolder { subfolder } ."
472
475
)
473
476
477
+ if os .path .isdir (model_id ):
478
+ model_id = subfolder
479
+ subfolder = ""
480
+
474
481
model_cache_path , preprocessors = cls ._cached_file (
475
- model_path = model_path ,
482
+ model_path = model_id ,
476
483
token = token ,
477
484
revision = revision ,
478
485
force_download = force_download ,
@@ -481,7 +488,7 @@ def _from_pretrained(
481
488
subfolder = subfolder ,
482
489
local_files_only = local_files_only ,
483
490
)
484
- new_model_save_dir = model_cache_path .parent
491
+ new_model_save_dir = Path ( model_cache_path ) .parent
485
492
486
493
# model_save_dir can be provided in kwargs as a TemporaryDirectory instance, in which case we want to keep it
487
494
# instead of the path only.
0 commit comments