File tree 1 file changed +5
-5
lines changed
1 file changed +5
-5
lines changed Original file line number Diff line number Diff line change 103
103
from transformers .deepspeed import deepspeed_init , deepspeed_load_checkpoint , is_deepspeed_zero3_enabled
104
104
105
105
if check_if_transformers_greater ("4.39" ):
106
- from transformers .utils import is_torch_xla_available
106
+ from transformers .utils import is_torch_xla_available as is_torch_tpu_xla_available
107
107
108
- if is_torch_xla_available ():
108
+ if is_torch_tpu_xla_available ():
109
109
import torch_xla .core .xla_model as xm
110
110
else :
111
- from transformers .utils import is_torch_tpu_available
111
+ from transformers .utils import is_torch_tpu_available as is_torch_tpu_xla_available
112
112
113
- if is_torch_tpu_available (check_device = False ):
113
+ if is_torch_tpu_xla_available (check_device = False ):
114
114
import torch_xla .core .xla_model as xm
115
115
116
116
if TYPE_CHECKING :
@@ -735,7 +735,7 @@ def get_dataloader_sampler(dataloader):
735
735
736
736
if (
737
737
args .logging_nan_inf_filter
738
- and not is_torch_tpu_available ()
738
+ and not is_torch_tpu_xla_available ()
739
739
and (torch .isnan (tr_loss_step ) or torch .isinf (tr_loss_step ))
740
740
):
741
741
# if loss is nan or inf simply add the average of previous logged losses
You can’t perform that action at this time.
0 commit comments