We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 705f5b3 commit 4de93d1Copy full SHA for 4de93d1
torch/_subclasses/fake_tensor.py
@@ -32,7 +32,7 @@
32
TypeVar,
33
Union,
34
)
35
-from typing_extensions import Self, TypeIs
+from typing_extensions import Self, TypeGuard, TypeIs
36
from weakref import ReferenceType
37
38
import torch
@@ -1214,7 +1214,7 @@ def reset_nt_tensor_id_counter(self) -> None:
1214
# In this case, it's insufficient to test only one FakeTensor: you need
1215
# to distinguish between our fake tensor and other fake tensors. That's
1216
# what this function does.
1217
- def is_our_fake(self, t: object) -> TypeIs[FakeTensor]:
+ def is_our_fake(self, t: object) -> TypeGuard[FakeTensor]:
1218
return isinstance(t, FakeTensor) and t.fake_mode is self
1219
1220
# If we should avoid device init. This changes the behavior of various APIs:
0 commit comments