Skip to content

Commit 8d1347f

Browse files
authored
Add tf available and version (#2154)
* remove torch version
* add tf check
* fix
1 parent 605ed7e commit 8d1347f

File tree

6 files changed

+68
-27
lines changed

6 files changed

+68
-27
lines changed

optimum/exporters/onnx/base.py

+7-3
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,12 @@
4444
from ...utils import TORCH_MINIMUM_VERSION as GLOBAL_MIN_TORCH_VERSION
4545
from ...utils import TRANSFORMERS_MINIMUM_VERSION as GLOBAL_MIN_TRANSFORMERS_VERSION
4646
from ...utils.doc import add_dynamic_docstring
47-
from ...utils.import_utils import is_onnx_available, is_onnxruntime_available, is_transformers_version
47+
from ...utils.import_utils import (
48+
is_onnx_available,
49+
is_onnxruntime_available,
50+
is_torch_version,
51+
is_transformers_version,
52+
)
4853
from ..base import ExportConfig
4954
from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
5055
from .model_patcher import ModelPatcher, Seq2SeqModelPatcher
@@ -386,9 +391,8 @@ def is_torch_support_available(self) -> bool:
386391
`bool`: Whether the installed version of PyTorch is compatible with the model.
387392
"""
388393
if is_torch_available():
389-
from ...utils import torch_version
394+
return is_torch_version(">=", self.MIN_TORCH_VERSION.base_version)
390395

391-
return torch_version >= self.MIN_TORCH_VERSION
392396
return False
393397

394398
@property

optimum/exporters/onnx/convert.py

+3-4
Original file line number | Diff line number | Diff line change
@@ -851,17 +851,16 @@ def export(
851851
)
852852

853853
if is_torch_available() and isinstance(model, nn.Module):
854-
from ...utils import torch_version
854+
from ...utils.import_utils import _torch_version
855855

856856
if not is_torch_onnx_support_available():
857857
raise MinimumVersionError(
858-
f"Unsupported PyTorch version, minimum required is {TORCH_MINIMUM_VERSION}, got: {torch_version}"
858+
f"Unsupported PyTorch version, minimum required is {TORCH_MINIMUM_VERSION}, got: {_torch_version}"
859859
)
860860

861861
if not config.is_torch_support_available:
862862
raise MinimumVersionError(
863-
f"Unsupported PyTorch version for this model. Minimum required is {config.MIN_TORCH_VERSION},"
864-
f" got: {torch.__version__}"
863+
f"Unsupported PyTorch version for this model. Minimum required is {config.MIN_TORCH_VERSION}, got: {_torch_version}"
865864
)
866865

867866
export_output = export_pytorch(

optimum/utils/__init__.py

+2
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,9 @@
4343
is_onnxruntime_available,
4444
is_pydantic_available,
4545
is_sentence_transformers_available,
46+
is_tf_available,
4647
is_timm_available,
48+
is_torch_available,
4749
is_torch_onnx_support_available,
4850
is_torch_version,
4951
is_transformers_available,

optimum/utils/import_utils.py

+52-13
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
"""Import utilities."""
1515

16-
import importlib.metadata as importlib_metadata
16+
import importlib.metadata
1717
import importlib.util
1818
import inspect
1919
import operator as op
@@ -23,7 +23,6 @@
2323

2424
import numpy as np
2525
from packaging import version
26-
from transformers.utils import is_torch_available
2726

2827

2928
TORCH_MINIMUM_VERSION = version.parse("1.11.0")
@@ -64,14 +63,46 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
6463
_datasets_available = _is_package_available("datasets")
6564
_diffusers_available, _diffusers_version = _is_package_available("diffusers", return_version=True)
6665
_transformers_available, _transformers_version = _is_package_available("transformers", return_version=True)
66+
_torch_available, _torch_version = _is_package_available("torch", return_version=True)
67+
6768
# importlib.metadata.version seem to not be robust with the ONNX Runtime extensions (`onnxruntime-gpu`, etc.)
6869
_onnxruntime_available = _is_package_available("onnxruntime", return_version=False)
6970

70-
7171
# TODO : Remove
72-
torch_version = None
73-
if is_torch_available():
74-
torch_version = version.parse(importlib_metadata.version("torch"))
72+
torch_version = version.parse(importlib.metadata.version("torch")) if _torch_available else None
73+
74+
75+
# Note: _is_package_available("tensorflow") fails for tensorflow-cpu. Please test any changes to the line below
76+
# with tensorflow-cpu to make sure it still works!
77+
_tf_available = importlib.util.find_spec("tensorflow") is not None
78+
_tf_version = None
79+
if _tf_available:
80+
candidates = (
81+
"tensorflow",
82+
"tensorflow-cpu",
83+
"tensorflow-gpu",
84+
"tf-nightly",
85+
"tf-nightly-cpu",
86+
"tf-nightly-gpu",
87+
"tf-nightly-rocm",
88+
"intel-tensorflow",
89+
"intel-tensorflow-avx512",
90+
"tensorflow-rocm",
91+
"tensorflow-macos",
92+
"tensorflow-aarch64",
93+
)
94+
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
95+
for pkg in candidates:
96+
try:
97+
_tf_version = importlib.metadata.version(pkg)
98+
break
99+
except importlib.metadata.PackageNotFoundError:
100+
pass
101+
_tf_available = _tf_version is not None
102+
if _tf_available:
103+
if version.parse(_tf_version) < version.parse("2"):
104+
_tf_available = False
105+
_tf_version = _tf_version or "N/A"
75106

76107

77108
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
@@ -91,7 +122,7 @@ def compare_versions(library_or_version: Union[str, version.Version], operation:
91122
raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")
92123
operation = STR_OPERATION_TO_FUNC[operation]
93124
if isinstance(library_or_version, str):
94-
library_or_version = version.parse(importlib_metadata.version(library_or_version))
125+
library_or_version = version.parse(importlib.metadata.version(library_or_version))
95126
return operation(library_or_version, version.parse(requirement_version))
96127

97128

@@ -117,15 +148,15 @@ def is_torch_version(operation: str, reference_version: str):
117148
"""
118149
Compare the current torch version to a given reference with an operation.
119150
"""
120-
if not is_torch_available():
151+
if not _torch_available:
121152
return False
122153

123154
import torch
124155

125156
return compare_versions(version.parse(version.parse(torch.__version__).base_version), operation, reference_version)
126157

127158

128-
_is_torch_onnx_support_available = is_torch_available() and is_torch_version(">=", TORCH_MINIMUM_VERSION.base_version)
159+
_is_torch_onnx_support_available = _torch_available and is_torch_version(">=", TORCH_MINIMUM_VERSION.base_version)
129160

130161

131162
def is_torch_onnx_support_available():
@@ -176,9 +207,17 @@ def is_transformers_available():
176207
return _transformers_available
177208

178209

210+
def is_torch_available():
211+
return _torch_available
212+
213+
214+
def is_tf_available():
215+
return _tf_available
216+
217+
179218
def is_auto_gptq_available():
180219
if _auto_gptq_available:
181-
v = version.parse(importlib_metadata.version("auto_gptq"))
220+
v = version.parse(importlib.metadata.version("auto_gptq"))
182221
if v >= AUTOGPTQ_MINIMUM_VERSION:
183222
return True
184223
else:
@@ -189,7 +228,7 @@ def is_auto_gptq_available():
189228

190229
def is_gptqmodel_available():
191230
if _gptqmodel_available:
192-
v = version.parse(importlib_metadata.version("gptqmodel"))
231+
v = version.parse(importlib.metadata.version("gptqmodel"))
193232
if v >= GPTQMODEL_MINIMUM_VERSION:
194233
return True
195234
else:
@@ -260,10 +299,10 @@ def check_if_torch_greater(target_version: str) -> bool:
260299
Returns:
261300
bool: whether the check is True or not.
262301
"""
263-
if not is_torch_available():
302+
if not _torch_available:
264303
return False
265304

266-
return torch_version >= version.parse(target_version)
305+
return version.parse(_torch_version) >= version.parse(target_version)
267306

268307

269308
@contextmanager

optimum/utils/input_generators.py

+1-2
Original file line number | Diff line number | Diff line change
@@ -20,9 +20,8 @@
2020
from typing import Any, List, Optional, Tuple, Union
2121

2222
import numpy as np
23-
from transformers.utils import is_tf_available, is_torch_available
2423

25-
from ..utils import is_diffusers_version, is_transformers_version
24+
from ..utils import is_diffusers_version, is_tf_available, is_torch_available, is_transformers_version
2625
from .normalized_config import (
2726
NormalizedConfig,
2827
NormalizedEncoderDecoderConfig,

tests/exporters/onnx/test_onnx_export.py

+3-5
Original file line number | Diff line number | Diff line change
@@ -207,20 +207,18 @@ def _onnx_export(
207207
model.config.pad_token_id = 0
208208

209209
if is_torch_available():
210-
from optimum.utils import torch_version
210+
from optimum.utils.import_utils import _torch_version, _transformers_version
211211

212212
if not onnx_config.is_transformers_support_available:
213-
import transformers
214-
215213
pytest.skip(
216214
"Skipping due to incompatible Transformers version. Minimum required is"
217-
f" {onnx_config.MIN_TRANSFORMERS_VERSION}, got: {transformers.__version__}"
215+
f" {onnx_config.MIN_TRANSFORMERS_VERSION}, got: {_transformers_version}"
218216
)
219217

220218
if not onnx_config.is_torch_support_available:
221219
pytest.skip(
222220
"Skipping due to incompatible PyTorch version. Minimum required is"
223-
f" {onnx_config.MIN_TORCH_VERSION}, got: {torch_version}"
221+
f" {onnx_config.MIN_TORCH_VERSION}, got: {_torch_version}"
224222
)
225223

226224
atol = onnx_config.ATOL_FOR_VALIDATION

0 commit comments

Comments (0)