
Commit 08db735

Skylion007 authored and pytorchmergebot committed
[BE]: Update mypy to 1.13.0 (pytorch#140808)
Update mypy to 1.13.0. This should reduce linting time; mypy 1.13 also supports orjson cache serialization, which should improve mypy cache performance if orjson is installed.
Pull Request resolved: pytorch#140808
Approved by: https://github.com/ezyang, https://github.com/malfet
1 parent 34127fc commit 08db735
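Not part of the diff: a quick sanity check one might run in the lint environment to confirm the new pin and whether orjson is importable (mypy only uses its faster cache serialization when orjson is present). The snippet is illustrative only.

# Hypothetical check, not included in this commit.
import importlib.util
from importlib.metadata import version

print("mypy:", version("mypy"))  # expected to print 1.13.0 after this change
print("orjson available:", importlib.util.find_spec("orjson") is not None)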

31 files changed: +114 -71 lines changed

.ci/docker/requirements-ci.txt (+1 -1)

@@ -90,7 +90,7 @@ librosa>=0.6.2 ; python_version < "3.11"
 #Pinned versions:
 #test that import:
 
-mypy==1.11.2
+mypy==1.13.0
 # Pin MyPy version because new errors are likely to appear with each release
 #Description: linter
 #Pinned versions: 1.10.0

.lintrunner.toml (+1 -1)

@@ -144,7 +144,7 @@ init_command = [
     'numpy==1.26.4 ; python_version >= "3.9" and python_version <= "3.11"',
     'numpy==2.1.0 ; python_version >= "3.12"',
     'expecttest==0.2.1',
-    'mypy==1.11.2',
+    'mypy==1.13.0',
     'sympy==1.13.0 ; python_version >= "3.9"',
     'types-requests==2.27.25',
     'types-PyYAML==6.0.7',

tools/flight_recorder/components/types.py (+2)

@@ -386,9 +386,11 @@ def __init__(
         }, f"{type} is not a supported operation"
         self.type = type
         if type == "send":
+            assert isinstance(meta, str)
             s, d = meta.split("->")
             self._src, self._dst = int(s), int(d)
         elif type == "recv":
+            assert isinstance(meta, str)
             d, s = meta.split("<-")
             self._dst, self._src = int(d), int(s)
         else:
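The added asserts are the usual way to narrow a Union for mypy before calling a str-only method. A minimal sketch of the pattern, not code from this commit (the function and its meta format are hypothetical):

# Illustrative sketch only.
from typing import Tuple, Union

def parse_send(meta: Union[str, int]) -> Tuple[int, int]:
    assert isinstance(meta, str)  # mypy narrows meta from Union[str, int] to str here
    src, dst = meta.split("->")
    return int(src), int(dst)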

torch/_dynamo/eval_frame.py (+1)

@@ -1503,6 +1503,7 @@ def result_capturing_wrapper(*graph_inputs):
         # NB: this is wrong if graph_captured_result has
         # data-dependent output size!
         ignore_fresh_unbacked = null_context()
+        assert ambient_fake_mode is not None
         if shape_env := ambient_fake_mode.shape_env:
             ignore_fresh_unbacked = shape_env.ignore_fresh_unbacked_symbols()

torch/_dynamo/output_graph.py (+19 -5)

@@ -11,7 +11,18 @@
 import traceback
 import weakref
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    TYPE_CHECKING,
+    Union,
+)
 
 import sympy
 
@@ -621,8 +632,11 @@ def save_global_state(self, out=None):
         """
         Saves to out if it is provided. Else saves to the tracing context's global_state.
         """
-        global_state = (
-            out if out is not None else self.tracing_context.global_context.global_state
+        global_state = cast(
+            Dict[str, Tuple[Callable[..., Any], bool]],
+            out
+            if out is not None
+            else self.tracing_context.global_context.global_state,
         )
 
         # TODO - Consider having a torch level API for torch_function_state. As
@@ -645,11 +659,11 @@ def save_global_state(self, out=None):
             functools.partial(torch.set_autocast_enabled, "cpu"),
             torch.is_autocast_enabled("cpu"),
         )
-        global_state["autocast_gpu_dtype"] = (
+        global_state["autocast_gpu_dtype"] = (  # type:ignore[assignment]
            functools.partial(torch.set_autocast_dtype, "cuda"),
            torch.get_autocast_dtype("cuda"),
        )
-        global_state["autocast_cpu_dtype"] = (
+        global_state["autocast_cpu_dtype"] = (  # type:ignore[assignment]
            functools.partial(torch.set_autocast_dtype, "cpu"),
            torch.get_autocast_dtype("cpu"),
        )
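The cast in the hunk above only affects static type checking; it performs no runtime conversion. A minimal, hypothetical sketch of the pattern (not code from this commit):

# Illustrative sketch only.
from typing import Any, Callable, Dict, Tuple, cast

def as_global_state(out: object) -> Dict[str, Tuple[Callable[..., Any], bool]]:
    # cast() returns `out` unchanged at runtime; it only tells mypy which type to assume.
    return cast(Dict[str, Tuple[Callable[..., Any], bool]], out)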

torch/_dynamo/utils.py (+2 -2)

@@ -1108,7 +1108,7 @@ class ChromiumEventLogger:
     a specification of the Chromium Event JSON format.
     """
 
-    def get_stack(self):
+    def get_stack(self) -> List[str]:
         """
         The main event stack, with every chromium event.
         Logged to tlparse.
@@ -1119,7 +1119,7 @@ def get_stack(self):
         self.tls.stack = []
         return self.tls.stack
 
-    def get_top(self) -> str:
+    def get_top(self) -> Optional[str]:
         """
         Get the top event name or None if the stack is empty.
         """

torch/_export/serde/dynamic_shapes.py (+2 -2)

@@ -166,8 +166,8 @@ def _track_dim_from_dims(
     root = val.root if isinstance(val, _DerivedDim) else val  # type: ignore[attr-defined]
     if root.__name__ not in dims:
         dims[root.__name__] = {
-            "min": root.min,
-            "max": root.max,
+            "min": root.min,  # type: ignore[attr-defined,union-attr]
+            "max": root.max,  # type: ignore[attr-defined,union-attr]
             "derived": set(),
         }

torch/_export/serde/serialize.py (+1 -1)

@@ -2423,7 +2423,7 @@ def _dict_to_dataclass(cls, data):
         field_type = cls.__annotations__[_type]
         return cls.create(**{_type: _dict_to_dataclass(field_type, _value)})
     elif dataclasses.is_dataclass(cls):
-        obj = cls(**data)  # type: ignore[assignment]
+        obj = cls(**data)  # type: ignore[assignment,operator]
         type_hints = typing.get_type_hints(cls)
         for f in dataclasses.fields(cls):
             name = f.name
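mypy ignore comments can be scoped to specific error codes, so widening `# type: ignore[assignment]` to `# type: ignore[assignment,operator]` suppresses only those two diagnostics on that one line. A minimal, hypothetical sketch (the deliberate type error is what the ignore suppresses):

# Illustrative sketch only; not code from this commit.
from typing import Dict

totals: Dict[str, int] = {}
totals["count"] = "3"  # type: ignore[assignment]  # silences only the assignment error code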

torch/_inductor/codegen/cpp.py (+1 -1)

@@ -292,7 +292,7 @@ def reduction_prefix_array(
     acc_type: str,
     reduction_type: str,
     dtype: torch.dtype,
-    len: int,
+    len: Union[str, int],
     init_fn,
 ):
     """

torch/_inductor/codegen/cpp_gemm_template.py (+3 -3)

@@ -308,8 +308,8 @@ def transpose_w(
 
 
 def expand_bias(
-    B: Union[ir.IRNode, torch.Tensor], X: Union[ir.IRNode, torch.Tensor]
-) -> Union[ir.IRNode, torch.Tensor]:
+    B: Union[ir.IRNode, torch.Tensor, None], X: Union[ir.IRNode, torch.Tensor]
+) -> Optional[Union[ir.IRNode, torch.Tensor]]:
     """
     Expand Bias to the same size of X.
     """
@@ -870,7 +870,7 @@ def normalize_shapes(inputs, layout_or_out):
         W = new_inputs[1]
         B = new_inputs[2] if has_bias else None
         W = transpose_w(W, trans_w)
-        B = expand_bias(B, X)
+        B = expand_bias(B, X)  # type:ignore[arg-type]
         new_inputs[1] = W
         if B is not None:
             new_inputs[2] = B

torch/_inductor/compile_fx.py (+1 -1)

@@ -382,7 +382,7 @@ def split_const_gm(
             gm,
             node,
             (
-                const_result[const_outputs[node.name]]
+                const_result[const_outputs[node.name]]  # type:ignore[index]
                 if lifted_constant_names is None
                 else None
             ),

torch/_inductor/graph.py (+2 -2)

@@ -1883,9 +1883,9 @@ def materialize(
         # Generating random inputs based on self.example_inputs sometimes can be problematic,
         # e.g. illegal memory access. A comprehensive fix is to autotune in a separate process.
         real_inputs = [
-            materialize(x)
+            materialize(x)  # type:ignore[arg-type]
             for x in (
-                self.example_inputs
+                self.example_inputs  # type:ignore[union-attr]
                 if isinstance(V.real_inputs, NullHandler)
                 else V.real_inputs
             )

torch/_prims_common/__init__.py (+5 -3)

@@ -1612,7 +1612,7 @@ def reduction_dtypes(
 # batched_matrix_contiguous_strides and contiguous_strides
 def make_contiguous_strides_for(
     shape: ShapeType, row_major: bool = True
-) -> Tuple[int, ...]:
+) -> Tuple[Union[_IntLikeT, int], ...]:
     """
     Returns the strides of a contiguous tensor if row_major
     If row_major=True, it returns the strides of a contiguous batch of Fortran-contiguous matrices
@@ -1625,11 +1625,13 @@ def make_contiguous_strides_for(
 
     from torch.fx.experimental.symbolic_shapes import is_nested_int
 
-    multiplier = 1
+    multiplier: Union[_IntLikeT, int] = 1
     strides = []
     for l in reversed(shape):
         strides.append(multiplier)
-        multiplier *= l if is_nested_int(l) else sym_max(l, 1)
+        multiplier *= (
+            l if is_nested_int(l) else sym_max(l, 1)
+        )  # type:ignore[assignment]
 
     result = tuple(reversed(strides))

torch/_refs/__init__.py (+1 -1)

@@ -410,7 +410,7 @@ def _broadcast_shapes(*_shapes):
         assert isinstance(shape, Sequence)
 
     # Computes common shape
-    common_shape = [
+    common_shape: List[Union[int, torch.SymInt]] = [
         1,
     ] * reduce(max, (len(shape) for shape in shapes))
     for arg_idx, shape in enumerate(shapes):

torch/ao/nn/qat/modules/conv.py (+1 -1)

@@ -20,7 +20,7 @@ def __init__(
         out_channels: int,
         kernel_size: Tuple[int, ...],
         stride: Tuple[int, ...],
-        padding: Tuple[int, ...],
+        padding: Union[str, Tuple[int, ...]],
         dilation: Tuple[int, ...],
         transposed: bool,
         output_padding: Tuple[int, ...],

torch/ao/ns/fx/weight_utils.py (+2 -2)

@@ -35,7 +35,7 @@ def get_lstm_weight(mod: nn.Module) -> List[torch.Tensor]:
     res = []
     for idx, param_name in enumerate(mod._flat_weights_names):  # type: ignore[arg-type]
         if "weight_ih_l" in param_name or "weight_hh_l" in param_name:
-            param_value = mod._flat_weights[idx].detach()  # type: ignore[index]
+            param_value = mod._flat_weights[idx].detach()  # type: ignore[index,union-attr]
             res.append(param_value)
     return res
 
@@ -72,7 +72,7 @@ def get_lstm_mod_weights(mod: nn.Module) -> List[torch.Tensor]:
         res = []
         for idx, param_name in enumerate(mod._flat_weights_names):
             if "weight_ih_l" in param_name or "weight_hh_l" in param_name:
-                param_value = mod._flat_weights[idx].detach()
+                param_value = mod._flat_weights[idx].detach()  # type: ignore[index,union-attr]
                 res.append(param_value)
         return res
     else:

torch/ao/quantization/fx/prepare.py (+1 -1)

@@ -665,7 +665,7 @@ def _get_output_act_obs_or_fq(
     named_modules: Dict[str, torch.nn.Module],
     obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
     is_qat: bool,
-) -> ObserverOrFakeQuantize:
+) -> Optional[ObserverOrFakeQuantize]:
     """Get the constructor for observer or fake quant object for
     the argument in the original graph as the output of previous node,
     skipping inserted observers

torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py (+1 -1)

@@ -105,7 +105,7 @@ def post_localSGD_hook(
     # Run allreduce using `global_group_to_use` in the first `start_localSGD_iter` iterations.
     if state.iter < state.start_localSGD_iter:
         state.maybe_increase_iter(bucket)
-        return default._allreduce_fut(global_group_to_use, input_tensor)
+        return default._allreduce_fut(global_group_to_use, input_tensor)  # type: ignore[arg-type]
 
     # If `post_local_gradient_allreduce` is not set,
     # then no gradient synchronization after the first `start_localSGD_iter` iterations.

torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py (+7 -2)

@@ -7,6 +7,7 @@
 import torch
 import torch.distributed as dist
 from torch.distributed import distributed_c10d
+from torch.utils._typing_utils import not_none
 
 from . import default_hooks as default
 
@@ -398,7 +399,9 @@ def powerSGD_hook(
         >>> ddp_model.register_comm_hook(state, powerSGD_hook)
     """  # noqa: B950
     process_group = state.process_group
-    group_to_use = process_group if process_group is not None else dist.group.WORLD
+    group_to_use = (
+        process_group if process_group is not None else not_none(dist.group.WORLD)
+    )
     world_size = group_to_use.size()
 
     # The input tensor is a flattened 1D tensor.
@@ -707,7 +710,9 @@ def batched_powerSGD_hook(
         >>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)
     """  # noqa: B950
     process_group = state.process_group
-    group_to_use = process_group if process_group is not None else dist.group.WORLD
+    group_to_use = (
+        process_group if process_group is not None else not_none(dist.group.WORLD)
+    )
     world_size = group_to_use.size()
 
     # The input tensor is a flattened 1D tensor.
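`not_none` is imported here from torch.utils._typing_utils. A minimal sketch of what such a helper typically looks like, assuming the usual pattern; this is not the actual torch implementation:

# Illustrative sketch only.
from typing import Optional, TypeVar

T = TypeVar("T")

def not_none(value: Optional[T]) -> T:
    # Narrows Optional[T] to T for the type checker and fails fast at runtime.
    if value is None:
        raise TypeError("expected a non-None value")
    return value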

torch/distributed/algorithms/model_averaging/averagers.py (+10 -5)

@@ -1,11 +1,12 @@
 # mypy: allow-untyped-defs
 import warnings
 from abc import ABC, abstractmethod
-from typing import Dict, Iterable, Union
+from typing import Dict, Iterable, Optional, Union
 
 import torch
 import torch.distributed as dist
 import torch.distributed.algorithms.model_averaging.utils as utils
+from torch.utils._typing_utils import not_none as _not_none
 
 
 __all__ = ["ModelAverager", "PeriodicModelAverager"]
@@ -21,9 +22,9 @@ class ModelAverager(ABC):
         will be used. (default: ``None``)
     """
 
-    def __init__(self, process_group=None):
+    def __init__(self, process_group: Optional[dist.ProcessGroup] = None):
         self.process_group = (
-            process_group if process_group is not None else dist.group.WORLD
+            process_group if process_group is not None else _not_none(dist.group.WORLD)
         )
         self.step = 0
 
@@ -85,7 +86,9 @@ class PeriodicModelAverager(ModelAverager):
         >>> averager.average_parameters(model.parameters())
     """
 
-    def __init__(self, period, warmup_steps=0, process_group=None):
+    def __init__(
+        self, period, warmup_steps=0, process_group: Optional[dist.ProcessGroup] = None
+    ):
         super().__init__(process_group)
         if warmup_steps < 0:
             raise ValueError("Arg ``warmup_steps`` must be a non-negative number.")
@@ -120,5 +123,7 @@ def average_parameters(
             self.step >= self.warmup_steps
             and (self.step - self.warmup_steps) % self.period == 0
         ):
-            utils.average_parameters_or_parameter_groups(params, self.process_group)
+            utils.average_parameters_or_parameter_groups(
+                params, _not_none(self.process_group)
+            )
         self.step += 1

torch/distributed/distributed_c10d.py (+11 -3)

@@ -4477,7 +4477,9 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False
 
 
 @_exception_logger
-def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None):
+def barrier(
+    group: Optional[ProcessGroup] = GroupMember.WORLD, async_op=False, device_ids=None
+):
     """
     Synchronize all processes.
 
@@ -4519,7 +4521,11 @@ def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None):
         work.wait()
 
 
-def monitored_barrier(group=GroupMember.WORLD, timeout=None, wait_all_ranks=False):
+def monitored_barrier(
+    group: Optional[ProcessGroup] = GroupMember.WORLD,
+    timeout=None,
+    wait_all_ranks=False,
+):
     """
     Synchronize processes similar to ``torch.distributed.barrier``, but consider a configurable timeout.
 
@@ -4589,7 +4595,9 @@ def monitored_barrier(group=GroupMember.WORLD, timeout=None, wait_all_ranks=Fals
     _check_valid_timeout(timeout)
 
     group_to_use = _get_default_group() if group is None else group
-    return group_to_use.monitored_barrier(timeout, wait_all_ranks=wait_all_ranks)
+    return group_to_use.monitored_barrier(  # type:ignore[attr-defined]
+        timeout, wait_all_ranks=wait_all_ranks
+    )
 
 
 def _create_process_group_wrapper(

torch/distributed/fsdp/_optim_utils.py (+3 -3)

@@ -630,7 +630,7 @@ def _flatten_optim_state(
     assert state_names is not None
 
     # Flatten the state
-    flat_state: Dict[str, Any] = {}
+    flat_state: Dict[str, Optional[torch.Tensor]] = {}
     for state_name in state_names:
         state_values = [
             unflat_param_state[state_name] if unflat_param_state is not None else None
@@ -658,7 +658,7 @@ def _flatten_optim_state(
         if are_pos_dim_tensors:
             flat_tensor = _flatten_tensor_optim_state(
                 state_name,
-                state_values,
+                state_values,  # type: ignore[arg-type]
                 unflat_param_names,
                 unflat_param_shapes,
                 handle,
@@ -680,7 +680,7 @@ def _flatten_optim_state(
         elif are_zero_dim_tensors:
             flat_state[state_name] = _flatten_zero_dim_tensor_optim_state(
                 state_name,
-                state_values,
+                state_values,  # type: ignore[arg-type]
                 unflat_param_names,
             )
         else:
