
Commit 2fb5ea5

Fix is_torch_tpu_available in ORT Trainer (#2028)
1 parent bf1befd commit 2fb5ea5

File tree

1 file changed: +5 −5 lines


optimum/onnxruntime/trainer.py (+5 −5)
@@ -103,14 +103,14 @@
 from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
 
 if check_if_transformers_greater("4.39"):
-    from transformers.utils import is_torch_xla_available
+    from transformers.utils import is_torch_xla_available as is_torch_tpu_xla_available
 
-    if is_torch_xla_available():
+    if is_torch_tpu_xla_available():
         import torch_xla.core.xla_model as xm
 else:
-    from transformers.utils import is_torch_tpu_available
+    from transformers.utils import is_torch_tpu_available as is_torch_tpu_xla_available
 
-    if is_torch_tpu_available(check_device=False):
+    if is_torch_tpu_xla_available(check_device=False):
         import torch_xla.core.xla_model as xm
 
 if TYPE_CHECKING:
@@ -735,7 +735,7 @@ def get_dataloader_sampler(dataloader):
 
                 if (
                     args.logging_nan_inf_filter
-                    and not is_torch_tpu_available()
+                    and not is_torch_tpu_xla_available()
                     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
                 ):
                     # if loss is nan or inf simply add the average of previous logged losses
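
For context, the change amounts to the version-gated alias pattern sketched below. This is a minimal standalone sketch, not the optimum module itself: it checks the installed transformers version with packaging.version instead of optimum's check_if_transformers_greater helper, and it imports whichever availability check exists under a single local name so later call sites (such as the NaN/Inf logging filter above) work on both old and new transformers releases.

# Minimal sketch (assumption: standalone form, not the actual optimum code).
import transformers
from packaging import version

if version.parse(transformers.__version__) >= version.parse("4.39.0"):
    # transformers >= 4.39 exposes is_torch_xla_available; alias it locally.
    from transformers.utils import is_torch_xla_available as is_torch_tpu_xla_available
else:
    # Older releases still ship is_torch_tpu_available; alias it to the same name.
    from transformers.utils import is_torch_tpu_available as is_torch_tpu_xla_available

# Call sites only ever see the alias, so they are unaffected by the rename.
if is_torch_tpu_xla_available():
    import torch_xla.core.xla_model as xm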

0 commit comments