Skip to content

Commit 6335599

Browse files
authored
Infer whether a model needs to be exported to ONNX or not (#2181)
* add files matching patterns
* rename
* Infer if model needs to be exported to ONNX
* adapt test
* add test for export diffusers model
* adapt test
* fix for local files
* set subfolder for local dir
* force export in tests
* check for all available models
* add tests
* fix model files loading when detected
* fix warning message
* refacto test
* add warning when filename ignored
* fix style
* fix for windows
1 parent 512d5c6 commit 6335599

File tree

9 files changed

+229
-219
lines changed

9 files changed

+229
-219
lines changed

.github/workflows/test_onnxruntime.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,4 +64,4 @@ jobs:
6464
run: |
6565
pytest tests/onnxruntime -m "not run_in_series" --durations=0 -vvvv -n auto
6666
env:
67-
HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
67+
HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}

optimum/onnxruntime/constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@
1616
DECODER_ONNX_FILE_PATTERN = r"(.*)?decoder((?!(with_past|merged)).)*?\.onnx"
1717
DECODER_WITH_PAST_ONNX_FILE_PATTERN = r"(.*)?decoder(.*)?with_past(.*)?\.onnx"
1818
DECODER_MERGED_ONNX_FILE_PATTERN = r"(.*)?decoder(.*)?merged(.*)?\.onnx"
19+
ONNX_FILE_PATTERN = r".*\.onnx$"

optimum/onnxruntime/modeling_decoder.py

+60-53
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
"""Classes handling causal-lm related architectures in ONNX Runtime."""
1515

1616
import logging
17+
import os
18+
import re
1719
from pathlib import Path
1820
from tempfile import TemporaryDirectory
1921
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
@@ -32,12 +34,17 @@
3234
from ..exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS, main_export
3335
from ..onnx.utils import check_model_uses_external_data
3436
from ..utils import NormalizedConfigManager, is_transformers_version
35-
from ..utils.modeling_utils import MODEL_TO_PATCH_FOR_PAST
37+
from ..utils.file_utils import find_files_matching_pattern
3638
from ..utils.save_utils import maybe_save_preprocessors
37-
from .constants import DECODER_MERGED_ONNX_FILE_PATTERN, DECODER_ONNX_FILE_PATTERN, DECODER_WITH_PAST_ONNX_FILE_PATTERN
39+
from .constants import (
40+
DECODER_MERGED_ONNX_FILE_PATTERN,
41+
DECODER_ONNX_FILE_PATTERN,
42+
DECODER_WITH_PAST_ONNX_FILE_PATTERN,
43+
ONNX_FILE_PATTERN,
44+
)
3845
from .modeling_ort import ONNX_MODEL_END_DOCSTRING, ORTModel
3946
from .models.bloom import bloom_convert_to_bloom_cache, bloom_convert_to_standard_cache
40-
from .utils import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_WEIGHTS_NAME
47+
from .utils import ONNX_WEIGHTS_NAME
4148

4249

4350
if TYPE_CHECKING:
@@ -400,7 +407,6 @@ def _from_pretrained(
400407
**kwargs,
401408
) -> "ORTModelForCausalLM":
402409
generation_config = kwargs.pop("generation_config", None)
403-
model_path = Path(model_id)
404410

405411
# We do not implement the logic for use_cache=False, use_merged=True
406412
if use_cache is False:
@@ -411,68 +417,69 @@ def _from_pretrained(
411417
)
412418
use_merged = False
413419

414-
decoder_name = "decoder_file_name" if use_cache else "decoder_with_past_file_name"
415-
decoder_file_name = kwargs.pop(decoder_name, None)
420+
onnx_files = find_files_matching_pattern(
421+
model_id,
422+
ONNX_FILE_PATTERN,
423+
glob_pattern="**/*.onnx",
424+
subfolder=subfolder,
425+
token=token,
426+
revision=revision,
427+
)
428+
429+
if len(onnx_files) == 0:
430+
raise FileNotFoundError(f"Could not find any ONNX model file in {model_id}")
416431

417-
if decoder_file_name is not None:
418-
logger.warning(f"The `{decoder_name}` argument is deprecated, please use `file_name` instead.")
419-
file_name = file_name or decoder_file_name
432+
if len(onnx_files) == 1:
433+
subfolder = onnx_files[0].parent
434+
_file_name = onnx_files[0].name
435+
if file_name and file_name != _file_name:
436+
raise FileNotFoundError(f"Trying to load {file_name} but only found {_file_name}")
437+
file_name = _file_name
420438

421-
if file_name is None:
422-
decoder_path = None
423-
# We use `is not False` here to include two cases: use_merged = None (in which case we auto-detect it),
424-
# and use_merged = True (explicitely specified by the user)
439+
else:
440+
model_files = []
441+
# Check first for merged models and then for decoder / decoder_with_past models
425442
if use_merged is not False:
426-
try:
427-
decoder_path = ORTModelForCausalLM.infer_onnx_filename(
428-
model_id,
429-
[DECODER_MERGED_ONNX_FILE_PATTERN],
430-
argument_name=None,
431-
subfolder=subfolder,
432-
token=token,
433-
revision=revision,
434-
)
435-
use_merged = True
436-
file_name = decoder_path.name
437-
except FileNotFoundError as e:
438-
if use_merged is True:
439-
raise FileNotFoundError(
440-
"The parameter `use_merged=True` was passed to ORTModelForCausalLM.from_pretrained()"
441-
" but no ONNX file for a merged decoder could be found in"
442-
f" {str(Path(model_id, subfolder))}, with the error: {e}"
443-
)
444-
use_merged = False
443+
model_files = [p for p in onnx_files if re.search(DECODER_MERGED_ONNX_FILE_PATTERN, str(p))]
444+
use_merged = len(model_files) != 0
445445

446446
if use_merged is False:
447447
pattern = DECODER_WITH_PAST_ONNX_FILE_PATTERN if use_cache else DECODER_ONNX_FILE_PATTERN
448-
# exclude decoder file for first iteration
449-
decoder_path = ORTModelForCausalLM.infer_onnx_filename(
450-
model_id,
451-
[r"^((?!decoder).)*.onnx", pattern],
452-
argument_name=None,
453-
subfolder=subfolder,
454-
token=token,
455-
revision=revision,
456-
)
457-
file_name = decoder_path.name
448+
model_files = [p for p in onnx_files if re.search(pattern, str(p))]
458449

459-
if file_name == ONNX_DECODER_WITH_PAST_NAME and config.model_type in MODEL_TO_PATCH_FOR_PAST:
460-
raise ValueError(
461-
f"ONNX Runtime inference using {ONNX_DECODER_WITH_PAST_NAME} has been deprecated for {config.model_type} architecture. Please re-export your model with optimum>=1.14.0 or set use_cache=False. For details about the deprecation, please refer to https://github.com/huggingface/optimum/releases/tag/v1.14.0."
450+
# if file_name is specified we don't filter legacy models
451+
if not model_files or file_name:
452+
model_files = onnx_files
453+
else:
454+
logger.warning(
455+
f"Legacy models found in {model_files} will be loaded. "
456+
"Legacy models will be deprecated in the next version of optimum, please re-export your model"
462457
)
458+
_file_name = model_files[0].name
459+
subfolder = model_files[0].parent
463460

464-
regular_file_names = []
465-
for name in [ONNX_WEIGHTS_NAME, ONNX_DECODER_WITH_PAST_NAME if use_cache else ONNX_DECODER_NAME]:
466-
regular_file_names += ORTModelForCausalLM._generate_regular_names_for_filename(name)
461+
defaut_file_name = file_name or "model.onnx"
462+
for file in model_files:
463+
if file.name == defaut_file_name:
464+
_file_name = file.name
465+
subfolder = file.parent
466+
break
467467

468-
if file_name not in regular_file_names:
468+
file_name = _file_name
469+
470+
if len(model_files) > 1:
469471
logger.warning(
470-
f"The ONNX file {file_name} is not a regular name used in optimum.onnxruntime that are {regular_file_names}, the "
471-
f"{cls.__name__} might not behave as expected."
472+
f"Too many ONNX model files were found in {' ,'.join(map(str, model_files))}. "
473+
"specify which one to load by using the `file_name` and/or the `subfolder` arguments. "
474+
f"Loading the file {file_name} in the subfolder {subfolder}."
472475
)
473476

477+
if os.path.isdir(model_id):
478+
model_id = subfolder
479+
subfolder = ""
480+
474481
model_cache_path, preprocessors = cls._cached_file(
475-
model_path=model_path,
482+
model_path=model_id,
476483
token=token,
477484
revision=revision,
478485
force_download=force_download,
@@ -481,7 +488,7 @@ def _from_pretrained(
481488
subfolder=subfolder,
482489
local_files_only=local_files_only,
483490
)
484-
new_model_save_dir = model_cache_path.parent
491+
new_model_save_dir = Path(model_cache_path).parent
485492

486493
# model_save_dir can be provided in kwargs as a TemporaryDirectory instance, in which case we want to keep it
487494
# instead of the path only.

0 commit comments

Comments (0)