13
13
# limitations under the License.
14
14
"""Import utilities."""
15
15
16
- import importlib .metadata as importlib_metadata
16
+ import importlib .metadata
17
17
import importlib .util
18
18
import inspect
19
19
import operator as op
23
23
24
24
import numpy as np
25
25
from packaging import version
26
- from transformers .utils import is_torch_available
27
26
28
27
29
28
TORCH_MINIMUM_VERSION = version .parse ("1.11.0" )
@@ -64,14 +63,46 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
64
63
_datasets_available = _is_package_available ("datasets" )
65
64
_diffusers_available , _diffusers_version = _is_package_available ("diffusers" , return_version = True )
66
65
_transformers_available , _transformers_version = _is_package_available ("transformers" , return_version = True )
66
+ _torch_available , _torch_version = _is_package_available ("torch" , return_version = True )
67
+
67
68
# importlib.metadata.version seem to not be robust with the ONNX Runtime extensions (`onnxruntime-gpu`, etc.)
68
69
_onnxruntime_available = _is_package_available ("onnxruntime" , return_version = False )
69
70
70
-
71
71
# TODO : Remove
72
- torch_version = None
73
- if is_torch_available ():
74
- torch_version = version .parse (importlib_metadata .version ("torch" ))
72
+ torch_version = version .parse (importlib .metadata .version ("torch" )) if _torch_available else None
73
+
74
+
75
+ # Note: _is_package_available("tensorflow") fails for tensorflow-cpu. Please test any changes to the line below
76
+ # with tensorflow-cpu to make sure it still works!
77
+ _tf_available = importlib .util .find_spec ("tensorflow" ) is not None
78
+ _tf_version = None
79
+ if _tf_available :
80
+ candidates = (
81
+ "tensorflow" ,
82
+ "tensorflow-cpu" ,
83
+ "tensorflow-gpu" ,
84
+ "tf-nightly" ,
85
+ "tf-nightly-cpu" ,
86
+ "tf-nightly-gpu" ,
87
+ "tf-nightly-rocm" ,
88
+ "intel-tensorflow" ,
89
+ "intel-tensorflow-avx512" ,
90
+ "tensorflow-rocm" ,
91
+ "tensorflow-macos" ,
92
+ "tensorflow-aarch64" ,
93
+ )
94
+ # For the metadata, we have to look for both tensorflow and tensorflow-cpu
95
+ for pkg in candidates :
96
+ try :
97
+ _tf_version = importlib .metadata .version (pkg )
98
+ break
99
+ except importlib .metadata .PackageNotFoundError :
100
+ pass
101
+ _tf_available = _tf_version is not None
102
+ if _tf_available :
103
+ if version .parse (_tf_version ) < version .parse ("2" ):
104
+ _tf_available = False
105
+ _tf_version = _tf_version or "N/A"
75
106
76
107
77
108
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
@@ -91,7 +122,7 @@ def compare_versions(library_or_version: Union[str, version.Version], operation:
91
122
raise ValueError (f"`operation` must be one of { list (STR_OPERATION_TO_FUNC .keys ())} , received { operation } " )
92
123
operation = STR_OPERATION_TO_FUNC [operation ]
93
124
if isinstance (library_or_version , str ):
94
- library_or_version = version .parse (importlib_metadata .version (library_or_version ))
125
+ library_or_version = version .parse (importlib . metadata .version (library_or_version ))
95
126
return operation (library_or_version , version .parse (requirement_version ))
96
127
97
128
@@ -117,15 +148,15 @@ def is_torch_version(operation: str, reference_version: str):
117
148
"""
118
149
Compare the current torch version to a given reference with an operation.
119
150
"""
120
- if not is_torch_available () :
151
+ if not _torch_available :
121
152
return False
122
153
123
154
import torch
124
155
125
156
return compare_versions (version .parse (version .parse (torch .__version__ ).base_version ), operation , reference_version )
126
157
127
158
128
- _is_torch_onnx_support_available = is_torch_available () and is_torch_version (">=" , TORCH_MINIMUM_VERSION .base_version )
159
+ _is_torch_onnx_support_available = _torch_available and is_torch_version (">=" , TORCH_MINIMUM_VERSION .base_version )
129
160
130
161
131
162
def is_torch_onnx_support_available ():
@@ -176,9 +207,17 @@ def is_transformers_available():
176
207
return _transformers_available
177
208
178
209
210
+ def is_torch_available ():
211
+ return _torch_available
212
+
213
+
214
+ def is_tf_available ():
215
+ return _tf_available
216
+
217
+
179
218
def is_auto_gptq_available ():
180
219
if _auto_gptq_available :
181
- v = version .parse (importlib_metadata .version ("auto_gptq" ))
220
+ v = version .parse (importlib . metadata .version ("auto_gptq" ))
182
221
if v >= AUTOGPTQ_MINIMUM_VERSION :
183
222
return True
184
223
else :
@@ -189,7 +228,7 @@ def is_auto_gptq_available():
189
228
190
229
def is_gptqmodel_available ():
191
230
if _gptqmodel_available :
192
- v = version .parse (importlib_metadata .version ("gptqmodel" ))
231
+ v = version .parse (importlib . metadata .version ("gptqmodel" ))
193
232
if v >= GPTQMODEL_MINIMUM_VERSION :
194
233
return True
195
234
else :
@@ -260,10 +299,10 @@ def check_if_torch_greater(target_version: str) -> bool:
260
299
Returns:
261
300
bool: whether the check is True or not.
262
301
"""
263
- if not is_torch_available () :
302
+ if not _torch_available :
264
303
return False
265
304
266
- return torch_version >= version .parse (target_version )
305
+ return version . parse ( _torch_version ) >= version .parse (target_version )
267
306
268
307
269
308
@contextmanager
0 commit comments