Skip to content

Commit a625fdc

Browse files
committed
Infer if model needs to be exported to ONNX
1 parent 4d7ed99 commit a625fdc

File tree

2 files changed

+46
-23
lines changed

2 files changed

+46
-23
lines changed

optimum/onnxruntime/modeling_ort.py

+43-2
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
"""ORTModelForXXX classes, allowing to run ONNX Models with ONNX Runtime using the same API as Transformers."""
1515

1616
import logging
17+
import os
1718
import re
1819
import shutil
1920
import warnings
@@ -65,7 +66,7 @@
6566
from ..exporters.onnx import main_export
6667
from ..modeling_base import FROM_PRETRAINED_START_DOCSTRING, OptimizedModel
6768
from ..onnx.utils import _get_external_data_paths
68-
from ..utils.file_utils import find_files_matching_pattern
69+
from ..utils.file_utils import _find_files_matching_pattern, find_files_matching_pattern
6970
from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
7071
from .io_binding import IOBindingHelper, TypeHelper
7172
from .utils import (
@@ -88,6 +89,7 @@
8889
_TOKENIZER_FOR_DOC = "AutoTokenizer"
8990
_FEATURE_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor"
9091
_PROCESSOR_FOR_DOC = "AutoProcessor"
92+
_FILE_PATTERN = r"^.*\.onnx$"
9193

9294
ONNX_MODEL_END_DOCSTRING = r"""
9395
This model inherits from [`~onnxruntime.modeling_ort.ORTModel`], check its documentation for the generic methods the
@@ -684,6 +686,7 @@ def from_pretrained(
684686
subfolder: str = "",
685687
config: Optional["PretrainedConfig"] = None,
686688
local_files_only: bool = False,
689+
revision: Optional[str] = None,
687690
provider: str = "CPUExecutionProvider",
688691
session_options: Optional[ort.SessionOptions] = None,
689692
provider_options: Optional[Dict[str, Any]] = None,
@@ -731,15 +734,53 @@ def from_pretrained(
731734
raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
732735
token = use_auth_token
733736

737+
_export = export
738+
try:
739+
if local_files_only:
740+
object_id = model_id.replace("/", "--")
741+
cached_model_dir = os.path.join(cache_dir, f"models--{object_id}")
742+
refs_file = os.path.join(os.path.join(cached_model_dir, "refs"), revision or "main")
743+
with open(refs_file) as f:
744+
revision = f.read()
745+
model_dir = os.path.join(cached_model_dir, "snapshots", revision)
746+
else:
747+
model_dir = model_id
748+
749+
onnx_files = _find_files_matching_pattern(
750+
model_dir,
751+
pattern=_FILE_PATTERN,
752+
subfolder=subfolder,
753+
token=token,
754+
revision=revision,
755+
)
756+
_export = len(onnx_files) == 0
757+
if _export ^ export:
758+
if export:
759+
logger.warning(
760+
f"The model {model_id} was already converted to ONNX but got `export=True`, the model will be converted to ONNX once again. "
761+
"Don't forget to save the resulting model with `.save_pretrained()`"
762+
)
763+
_export = True
764+
else:
765+
logger.warning(
766+
f"No ONNX files were found for {model_id}, setting `export=True` to convert the model to ONNX. "
767+
"Don't forget to save the resulting model with `.save_pretrained()`"
768+
)
769+
except Exception as exception:
770+
logger.warning(
771+
f"Could not infer whether the model was already converted or not to ONNX, keeping `export={export}`.\n{exception}"
772+
)
773+
734774
return super().from_pretrained(
735775
model_id,
736-
export=export,
776+
export=_export,
737777
force_download=force_download,
738778
token=token,
739779
cache_dir=cache_dir,
740780
subfolder=subfolder,
741781
config=config,
742782
local_files_only=local_files_only,
783+
revision=revision,
743784
provider=provider,
744785
session_options=session_options,
745786
provider_options=provider_options,

optimum/pipelines/pipelines_base.py

+3-21
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,6 @@
4747

4848
from ..bettertransformer import BetterTransformer
4949
from ..utils import is_onnxruntime_available, is_transformers_version
50-
from ..utils.file_utils import find_files_matching_pattern
5150

5251

5352
if is_onnxruntime_available():
@@ -242,28 +241,11 @@ def load_ort_pipeline(
242241

243242
if model is None:
244243
model_id = SUPPORTED_TASKS[targeted_task]["default"]
245-
model = SUPPORTED_TASKS[targeted_task]["class"][0].from_pretrained(model_id, export=True)
244+
model = SUPPORTED_TASKS[targeted_task]["class"][0].from_pretrained(model_id)
246245
elif isinstance(model, str):
247-
from ..onnxruntime.modeling_seq2seq import ENCODER_ONNX_FILE_PATTERN, ORTModelForConditionalGeneration
248-
249-
model_id = model
250-
ort_model_class = SUPPORTED_TASKS[targeted_task]["class"][0]
251-
252-
if issubclass(ort_model_class, ORTModelForConditionalGeneration):
253-
pattern = ENCODER_ONNX_FILE_PATTERN
254-
else:
255-
pattern = ".+?.onnx"
256-
257-
onnx_files = find_files_matching_pattern(
258-
model,
259-
pattern,
260-
glob_pattern="**/*.onnx",
261-
subfolder=subfolder,
262-
token=token,
263-
revision=revision,
246+
model = SUPPORTED_TASKS[targeted_task]["class"][0].from_pretrained(
247+
model, revision=revision, subfolder=subfolder, token=token, **model_kwargs
264248
)
265-
export = len(onnx_files) == 0
266-
model = ort_model_class.from_pretrained(model, export=export, **model_kwargs)
267249
elif isinstance(model, ORTModel):
268250
if tokenizer is None and load_tokenizer:
269251
for preprocessor in model.preprocessors:

0 commit comments

Comments (0)