|
14 | 14 | """ORTModelForXXX classes, allowing to run ONNX Models with ONNX Runtime using the same API as Transformers."""
|
15 | 15 |
|
16 | 16 | import logging
|
| 17 | +import os |
17 | 18 | import re
|
18 | 19 | import shutil
|
19 | 20 | import warnings
|
|
65 | 66 | from ..exporters.onnx import main_export
|
66 | 67 | from ..modeling_base import FROM_PRETRAINED_START_DOCSTRING, OptimizedModel
|
67 | 68 | from ..onnx.utils import _get_external_data_paths
|
68 |
| -from ..utils.file_utils import find_files_matching_pattern |
| 69 | +from ..utils.file_utils import _find_files_matching_pattern, find_files_matching_pattern |
69 | 70 | from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
|
70 | 71 | from .io_binding import IOBindingHelper, TypeHelper
|
71 | 72 | from .utils import (
|
|
88 | 89 | _TOKENIZER_FOR_DOC = "AutoTokenizer"
|
89 | 90 | _FEATURE_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor"
|
90 | 91 | _PROCESSOR_FOR_DOC = "AutoProcessor"
|
| 92 | +_FILE_PATTERN = r"^.*\.onnx$" |
91 | 93 |
|
92 | 94 | ONNX_MODEL_END_DOCSTRING = r"""
|
93 | 95 | This model inherits from [`~onnxruntime.modeling_ort.ORTModel`], check its documentation for the generic methods the
|
@@ -684,6 +686,7 @@ def from_pretrained(
|
684 | 686 | subfolder: str = "",
|
685 | 687 | config: Optional["PretrainedConfig"] = None,
|
686 | 688 | local_files_only: bool = False,
|
| 689 | + revision: Optional[str] = None, |
687 | 690 | provider: str = "CPUExecutionProvider",
|
688 | 691 | session_options: Optional[ort.SessionOptions] = None,
|
689 | 692 | provider_options: Optional[Dict[str, Any]] = None,
|
@@ -731,15 +734,53 @@ def from_pretrained(
|
731 | 734 | raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
|
732 | 735 | token = use_auth_token
|
733 | 736 |
|
| 737 | + _export = export |
| 738 | + try: |
| 739 | + if local_files_only: |
| 740 | + object_id = model_id.replace("/", "--") |
| 741 | + cached_model_dir = os.path.join(cache_dir, f"models--{object_id}") |
| 742 | + refs_file = os.path.join(os.path.join(cached_model_dir, "refs"), revision or "main") |
| 743 | + with open(refs_file) as f: |
| 744 | + revision = f.read() |
| 745 | + model_dir = os.path.join(cached_model_dir, "snapshots", revision) |
| 746 | + else: |
| 747 | + model_dir = model_id |
| 748 | + |
| 749 | + onnx_files = _find_files_matching_pattern( |
| 750 | + model_dir, |
| 751 | + pattern=_FILE_PATTERN, |
| 752 | + subfolder=subfolder, |
| 753 | + token=token, |
| 754 | + revision=revision, |
| 755 | + ) |
| 756 | + _export = len(onnx_files) == 0 |
| 757 | + if _export ^ export: |
| 758 | + if export: |
| 759 | + logger.warning( |
| 760 | + f"The model {model_id} was already converted to ONNX but got `export=True`, the model will be converted to ONNX once again. " |
| 761 | + "Don't forget to save the resulting model with `.save_pretrained()`" |
| 762 | + ) |
| 763 | + _export = True |
| 764 | + else: |
| 765 | + logger.warning( |
| 766 | + f"No ONNX files were found for {model_id}, setting `export=True` to convert the model to ONNX. " |
| 767 | + "Don't forget to save the resulting model with `.save_pretrained()`" |
| 768 | + ) |
| 769 | + except Exception as exception: |
| 770 | + logger.warning( |
| 771 | + f"Could not infer whether the model was already converted or not to ONNX, keeping `export={export}`.\n{exception}" |
| 772 | + ) |
| 773 | + |
734 | 774 | return super().from_pretrained(
|
735 | 775 | model_id,
|
736 |
| - export=export, |
| 776 | + export=_export, |
737 | 777 | force_download=force_download,
|
738 | 778 | token=token,
|
739 | 779 | cache_dir=cache_dir,
|
740 | 780 | subfolder=subfolder,
|
741 | 781 | config=config,
|
742 | 782 | local_files_only=local_files_only,
|
| 783 | + revision=revision, |
743 | 784 | provider=provider,
|
744 | 785 | session_options=session_options,
|
745 | 786 | provider_options=provider_options,
|
|
0 commit comments