
Commit 16caa8c

Skylion007 authored and pytorchmergebot committed
[BE]: Update Typeguard to TypeIs for better type inference (pytorch#133814)
Uses TypeIs instead of TypeGuard for better inference. See https://peps.python.org/pep-0742/

Pull Request resolved: pytorch#133814
Approved by: https://github.com/ezyang
1 parent 9bb327b commit 16caa8c
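For background, PEP 742's TypeIs differs from TypeGuard in two ways that matter here: a TypeIs function narrows in both branches of a conditional and intersects with the argument's previously known type, while TypeGuard narrows only the positive branch. TypeIs first shipped in typing-extensions 4.10.0, which is why the diffs below raise that version floor. A minimal sketch of the difference (illustrative names, not code from this commit):

from typing import Union
from typing_extensions import TypeGuard, TypeIs, reveal_type

def is_str_via_guard(x: object) -> TypeGuard[str]:
    return isinstance(x, str)

def is_str_via_typeis(x: object) -> TypeIs[str]:
    return isinstance(x, str)

def demo(x: Union[str, int]) -> None:
    if is_str_via_guard(x):
        reveal_type(x)  # str
    else:
        reveal_type(x)  # still Union[str, int]: TypeGuard gives no negative narrowing
    if is_str_via_typeis(x):
        reveal_type(x)  # str
    else:
        reveal_type(x)  # int: TypeIs removes str from the union

demo("hello")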

File tree: 12 files changed, +26 -26 lines


.ci/docker/requirements-ci.txt

+1 -1

@@ -257,7 +257,7 @@ tb-nightly==2.13.0a20230426
 #test that import:

 # needed by torchgen utils
-typing-extensions
+typing-extensions>=4.10.0
 #Description: type hints for python
 #Pinned versions:
 #test that import:

pyproject.toml

+1 -1

@@ -7,7 +7,7 @@ requires = [
     "ninja",
     "pyyaml",
     "cmake",
-    "typing-extensions",
+    "typing-extensions>=4.10.0",
     "requests",
 ]
 # Use legacy backend to import local packages in setup.py

requirements.txt

+1 -1

@@ -11,7 +11,7 @@ requests
 # is required until pytorch build not refactored to work for latest setuptools.
 setuptools<=72.1.0
 types-dataclasses
-typing-extensions>=4.8.0
+typing-extensions>=4.10.0
 sympy==1.13.1 ; python_version >= "3.9"
 filelock
 networkx

setup.py

+1 -1

@@ -1159,7 +1159,7 @@ def main():
     )
     install_requires = [
         "filelock",
-        "typing-extensions>=4.8.0",
+        "typing-extensions>=4.10.0",
         'setuptools ; python_version >= "3.12"',
         'sympy==1.13.1 ; python_version >= "3.9"',
         "networkx",

torch/__init__.py

+3 -3

@@ -34,7 +34,7 @@
     TypeVar as _TypeVar,
     Union as _Union,
 )
-from typing_extensions import ParamSpec as _ParamSpec, TypeGuard as _TypeGuard
+from typing_extensions import ParamSpec as _ParamSpec, TypeIs as _TypeIs


 if TYPE_CHECKING:

@@ -1008,7 +1008,7 @@ def typename(obj: _Any, /) -> str:
     return f"{module}.{qualname}"


-def is_tensor(obj: _Any, /) -> _TypeGuard["torch.Tensor"]:
+def is_tensor(obj: _Any, /) -> _TypeIs["torch.Tensor"]:
     r"""Returns True if `obj` is a PyTorch tensor.

     Note that this function is simply doing ``isinstance(obj, Tensor)``.

@@ -1028,7 +1028,7 @@ def is_tensor(obj: _Any, /) -> _TypeGuard["torch.Tensor"]:
     return isinstance(obj, torch.Tensor)


-def is_storage(obj: _Any, /) -> _TypeGuard[_Union["TypedStorage", "UntypedStorage"]]:
+def is_storage(obj: _Any, /) -> _TypeIs[_Union["TypedStorage", "UntypedStorage"]]:
     r"""Returns True if `obj` is a PyTorch storage object.

     Args:
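The practical effect of the is_tensor change above, sketched with hypothetical caller code: the positive branch narrowed to Tensor under TypeGuard as well, but only TypeIs lets the checker narrow the else branch away from the union.

from typing import Union

import torch

def as_tensor(x: Union[torch.Tensor, float]) -> torch.Tensor:
    if torch.is_tensor(x):
        return x  # positive branch: narrowed to Tensor, as with TypeGuard
    # negative branch: with TypeIs x narrows to float; TypeGuard left it as the union
    return torch.tensor(x)

print(as_tensor(1.5))  # tensor(1.5000)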

torch/_dynamo/utils.py

+3 -3

@@ -56,7 +56,7 @@
     Union,
     ValuesView,
 )
-from typing_extensions import Literal, TypeGuard
+from typing_extensions import Literal, TypeIs

 import torch
 import torch._functorch.config

@@ -569,14 +569,14 @@ def clear(self):


 @overload
-def istype(obj: object, allowed_types: Type[T]) -> TypeGuard[T]:
+def istype(obj: object, allowed_types: Type[T]) -> TypeIs[T]:
     ...


 @overload
 def istype(
     obj: object, allowed_types: Tuple[Type[List[T]], Type[Tuple[T, ...]]]
-) -> TypeGuard[T]:
+) -> TypeIs[T]:
     ...


torch/_inductor/pattern_matcher.py

+4 -4

@@ -70,7 +70,7 @@
     TypeVar,
     Union,
 )
-from typing_extensions import Self, TypeGuard
+from typing_extensions import Self, TypeIs

 import torch
 import torch._guards

@@ -305,10 +305,10 @@ def __bool__(self) -> bool:
 MatchResult = Union[Match, FailedMatch]


-def is_match(m: MatchResult) -> TypeGuard[Match]:
+def is_match(m: MatchResult) -> TypeIs[Match]:
     """
-    TypeGuards cannot act on `self`. Thus this function exists to let mypy
-    recognize FailedMatch.__bool__ as a TypeGuard.
+    TypeIs cannot act on `self`. Thus this function exists to let mypy
+    recognize FailedMatch.__bool__ as a TypeIs.
     """
     return bool(m)

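The docstring above alludes to a limitation: neither TypeGuard nor TypeIs can be declared on a method's self parameter, so the narrowing implied by FailedMatch.__bool__ has to be exposed through a free function. A self-contained sketch of that pattern (class bodies reduced to stubs, not the real pattern_matcher code):

from typing import Union
from typing_extensions import TypeIs, reveal_type

class Match:
    """Stub standing in for a successful pattern match; truthy by default."""

class FailedMatch:
    """Stub standing in for a failed match; falsy by construction."""
    def __bool__(self) -> bool:
        return False

MatchResult = Union[Match, FailedMatch]

def is_match(m: MatchResult) -> TypeIs[Match]:
    return bool(m)

def handle(m: MatchResult) -> None:
    if is_match(m):
        reveal_type(m)  # Match
    else:
        reveal_type(m)  # FailedMatch: the negative branch narrows too

handle(FailedMatch())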

torch/_subclasses/fake_tensor.py

+3 -3

@@ -32,7 +32,7 @@
     TypeVar,
     Union,
 )
-from typing_extensions import Self, TypeGuard
+from typing_extensions import Self, TypeIs
 from weakref import ReferenceType

 import torch

@@ -169,7 +169,7 @@ def get_plain_tensors(subclass: Tensor) -> List[Tensor]:
     return plain_tensors


-def is_fake(x: object) -> TypeGuard[Tensor]:
+def is_fake(x: object) -> TypeIs[Tensor]:
     if isinstance(x, FakeTensor):
         return True
     if is_traceable_wrapper_subclass(x):

@@ -1213,7 +1213,7 @@ def reset_nt_tensor_id_counter(self) -> None:
     # In this case, it's insufficient to test only one FakeTensor: you need
     # to distinguish between our fake tensor and other fake tensors. That's
     # what this function does.
-    def is_our_fake(self, t: object) -> TypeGuard[FakeTensor]:
+    def is_our_fake(self, t: object) -> TypeIs[FakeTensor]:
         return isinstance(t, FakeTensor) and t.fake_mode is self

     # If we should avoid device init. This changes the behavior of various APIs:

torch/masked/maskedtensor/core.py

+2 -2

@@ -3,7 +3,7 @@

 import warnings
 from typing import Any
-from typing_extensions import TypeGuard
+from typing_extensions import TypeIs

 import torch
 from torch.overrides import get_default_nowrap_functions

@@ -15,7 +15,7 @@
 ]


-def is_masked_tensor(obj: Any, /) -> TypeGuard["MaskedTensor"]:
+def is_masked_tensor(obj: Any, /) -> TypeIs["MaskedTensor"]:
     r"""Returns True if the input is a MaskedTensor, else False

     Args:

torch/nn/parameter.pyi

+2 -2

@@ -1,5 +1,5 @@
 # mypy: allow-untyped-defs
-from typing_extensions import TypeGuard
+from typing_extensions import TypeIs

 from torch import device, dtype, Tensor

@@ -8,7 +8,7 @@ class Parameter(Tensor):

 def is_lazy(
     param: Tensor,
-) -> TypeGuard[UninitializedParameter | UninitializedBuffer]: ...
+) -> TypeIs[UninitializedParameter | UninitializedBuffer]: ...

 class UninitializedParameter(Tensor):
     def __init__(self, data: Tensor = ..., requires_grad: bool = ...) -> None: ...

torch/serialization.py

+2 -2

@@ -28,7 +28,7 @@
     Type,
     Union,
 )
-from typing_extensions import TypeAlias, TypeGuard  # Python 3.10+
+from typing_extensions import TypeAlias, TypeIs

 import torch
 import torch._weights_only_unpickler as _weights_only_unpickler

@@ -620,7 +620,7 @@ def storage_to_tensor_type(storage):
     return getattr(module, storage_type.__name__.replace("Storage", "Tensor"))


-def _is_path(name_or_buffer) -> TypeGuard[Union[str, os.PathLike]]:
+def _is_path(name_or_buffer) -> TypeIs[Union[str, os.PathLike]]:
     return isinstance(name_or_buffer, (str, os.PathLike))


torch/utils/_python_dispatch.py

+3 -3

@@ -4,7 +4,7 @@
 import warnings
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Set, Union, Protocol, Tuple, Sequence, overload, Deque, Type
-from typing_extensions import TypeGuard
+from typing_extensions import TypeIs
 from collections import deque

 import torch

@@ -365,7 +365,7 @@ def to(


-def is_traceable_wrapper_subclass(t: object) -> TypeGuard[TensorWithFlatten]:
+def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]:
     """
     Returns whether or not a tensor subclass that implements __torch_dispatch__
     is 'traceable' with torch.compile.

@@ -402,7 +402,7 @@ def is_traceable_wrapper_subclass(t: object) -> TypeGuard[TensorWithFlatten]:
         and hasattr(t, "__tensor_unflatten__")
     )

-def is_traceable_wrapper_subclass_type(t: Type) -> TypeGuard[Type[TensorWithFlatten]]:
+def is_traceable_wrapper_subclass_type(t: Type) -> TypeIs[Type[TensorWithFlatten]]:
     """Same as above, but takes a type argument instead of an instance."""
     return (issubclass(t, torch.Tensor) and t != torch.Tensor
             and hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__"))
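As with the instance check, the Type[...] variant now narrows class objects for the caller. A hedged usage sketch (the caller code is illustrative, not from this commit):

import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type

def describe(cls: type) -> str:
    if is_traceable_wrapper_subclass_type(cls):
        # cls is narrowed to Type[TensorWithFlatten], so the protocol's
        # __tensor_flatten__/__tensor_unflatten__ attributes type-check
        return f"traceable wrapper subclass: {cls.__name__}"
    return f"plain type: {cls.__name__}"

print(describe(torch.Tensor))  # plain type: Tensor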
