
Commit fb0f285

suo authored and pytorchmergebot committed
[lint] upgrade mypy to latest version
Fixes pytorch#75927. Had to fix some bugs and add some ignores.

To check if clean:
```
lintrunner --paths-cmd='git grep -Il .' --take MYPY,MYPYSTRICT
```

Pull Request resolved: pytorch#76753
Approved by: https://github.com/malfet
1 parent 8473173 commit fb0f285

File tree: 30 files changed (+72, -62 lines)

.github/scripts/lint_native_functions.py (+3 -3)

```
@@ -27,9 +27,9 @@ def fn(base: str) -> str:
 contents = f.read()

 yaml = ruamel.yaml.YAML() # type: ignore[attr-defined]
-yaml.preserve_quotes = True
-yaml.width = 1000
-yaml.boolean_representation = ['False', 'True']
+yaml.preserve_quotes = True # type: ignore[assignment]
+yaml.width = 1000 # type: ignore[assignment]
+yaml.boolean_representation = ['False', 'True'] # type: ignore[attr-defined]
 r = yaml.load(contents)

 # Cuz ruamel's author intentionally didn't include conversion to string
```

.lintrunner.toml (+8 -1)

```
@@ -123,7 +123,14 @@ init_command = [
 '--dry-run={{DRYRUN}}',
 'numpy==1.20',
 'expecttest==0.1.3',
-'mypy==0.812',
+'mypy==0.950',
+'types-requests==2.27.25',
+'types-six==1.16.15',
+'types-PyYAML==6.0.7',
+'types-tabulate==0.8.8',
+'types-protobuf==3.19.18',
+'types-pkg-resources==0.1.3',
+'types-Jinja2==2.11.9',
 'junitparser==2.1.1',
 'rich==10.9.0',
 'pyyaml==6.0',
```
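Since mypy 0.900, third-party stubs are no longer bundled with mypy itself, which is why the upgrade pulls in the explicit `types-*` packages above. A minimal illustration of the kind of code that needs them (hypothetical snippet, not a file from this commit): without `types-requests` installed, mypy 0.950 reports missing library stubs for the import.

```python
import requests  # mypy 0.950 flags this import unless types-requests is installed


def fetch_status(url: str) -> int:
    # With the stub package installed, get() and status_code are fully typed.
    return requests.get(url, timeout=10).status_code
```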

benchmarks/instruction_counts/core/expand.py (+3 -6)

```
@@ -8,7 +8,7 @@
 import os
 import re
 import textwrap
-from typing import cast, List, Optional, Tuple, TYPE_CHECKING
+from typing import List, Optional, Tuple, TYPE_CHECKING
 import uuid

 import torch
@@ -63,15 +63,12 @@ def _generate_torchscript_file(model_src: str, name: str) -> Optional[str]:

 # Import magic to actually load our function.
 module_spec = importlib.util.spec_from_file_location(f"torchscript__{name}", module_path)
+assert module_spec is not None
 module = importlib.util.module_from_spec(module_spec)
 loader = module_spec.loader
 assert loader is not None

-# Module.loader has type Optional[_Loader]. Even when we assert loader is
-# not None and MyPy narrows it to type _Loader, it will not pass type
-# checks. So we have to use a cast to tell MyPy that _Loader implements
-# importlib.abc.Loader.
-cast(importlib.abc.Loader, loader).exec_module(module)
+loader.exec_module(module)

 # And again, the type checker has no way of knowing that this line is valid.
 jit_model = module.jit_model # type: ignore[attr-defined]
```
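The removed comment and `cast` are no longer needed because current mypy and typeshed let `exec_module` be called once both the spec and its loader have been narrowed away from `None`. A self-contained sketch of the same pattern, using a throwaway module written to a temporary directory (hypothetical content, not the benchmark's TorchScript module):

```python
import importlib.util
import pathlib
import tempfile

# Write a throwaway module so the example can actually be executed.
module_path = pathlib.Path(tempfile.mkdtemp()) / "example_mod.py"
module_path.write_text("answer = 42\n")

# spec_from_file_location returns Optional[ModuleSpec]; the assert narrows it for mypy.
module_spec = importlib.util.spec_from_file_location("example_mod", module_path)
assert module_spec is not None
module = importlib.util.module_from_spec(module_spec)

# ModuleSpec.loader is Optional as well; after the assert, no cast is required.
loader = module_spec.loader
assert loader is not None
loader.exec_module(module)

print(module.answer)  # type: ignore[attr-defined]
```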

tools/setup_helpers/cmake.py (+1 -1)

```
@@ -342,7 +342,7 @@ def generate(
 cmake_prefix_path = build_options.get("CMAKE_PREFIX_PATH", None)
 if cmake_prefix_path:
 build_options["CMAKE_PREFIX_PATH"] = (
-cast(str, py_lib_path) + ";" + cast(str, cmake_prefix_path)
+py_lib_path + ";" + cast(str, cmake_prefix_path)
 )
 else:
 build_options["CMAKE_PREFIX_PATH"] = py_lib_path
```

tools/shared/module_loader.py (+1)

```
@@ -7,6 +7,7 @@ def import_module(name: str, path: str) -> ModuleType:
 import importlib.util

 spec = importlib.util.spec_from_file_location(name, path)
+assert spec is not None
 module = importlib.util.module_from_spec(spec)
 cast(Loader, spec.loader).exec_module(module)
 return module
```

tools/stats/s3_stat_parser.py (+1 -1)

```
@@ -102,7 +102,7 @@ def get_S3_object_from_bucket(bucket_name: str, object: str) -> Any:

 def case_status(case: Version1Case) -> Status:
 for k in {"errored", "failed", "skipped"}:
-if case[k]: # type: ignore[misc]
+if case[k]: # type: ignore[literal-required]
 return cast(Status, k)
 return None
```
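The only change here is the error code: newer mypy reports TypedDict indexing with a non-literal key under `literal-required` instead of the catch-all `misc`. A small standalone illustration with a hypothetical TypedDict (not the real `Version1Case`):

```python
from typing import TypedDict


class Case(TypedDict):
    errored: bool
    failed: bool
    skipped: bool


def case_status(case: Case) -> str:
    for k in ("errored", "failed", "skipped"):
        # k is typed as str rather than a Literal, so the index needs an ignore.
        if case[k]:  # type: ignore[literal-required]
            return k
    return "passed"


print(case_status({"errored": False, "failed": True, "skipped": False}))
```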

tools/test/test_stats.py (+1 -1)

```
@@ -85,7 +85,7 @@ def make_report_v2(
 }
 files[file_name] = {
 "suites": suites,
-"total_seconds": sum(suite["total_seconds"] for suite in suites.values()),
+"total_seconds": sum(suite["total_seconds"] for suite in suites.values()), # type: ignore[type-var]
 }
 return {
 **dummy_meta_meta(), # type: ignore[misc]
```

torch/_C/__init__.pyi.in (+1 -1)

```
@@ -54,7 +54,7 @@ class Stream:
 class Size(Tuple[_int, ...]):
 # TODO: __reduce__

-@overload
+@overload # type: ignore[override]
 def __getitem__(self: Size, key: _int) -> _int: ...

 @overload
```

torch/_deploy.py (+1 -1)

```
@@ -82,7 +82,7 @@ def persistent_load(saved_id):
 importer = sys_importer

 unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes))
-unpickler.persistent_load = persistent_load
+unpickler.persistent_load = persistent_load # type: ignore[assignment]
 result = _deploy_objects[id] = unpickler.load()
 return result
```

torch/ao/nn/sparse/quantized/linear.py (+6 -6)

```
@@ -169,22 +169,22 @@ def from_float(cls, mod):
 assert hasattr(mod, 'sparse_params'), \
 ('Expecting the Linear to have `sparse_params`. Make sure you have provided arguments '
 'in the `sparsifier.squash_mask(params_to_save=("sparse_block_shape",))` method.')
-sparse_block_shape = mod.sparse_params.get('sparse_block_shape', None)
+sparse_block_shape = mod.sparse_params.get('sparse_block_shape', None) # type: ignore[operator, union-attr]
 assert isinstance(sparse_block_shape, (tuple, list))
 assert len(sparse_block_shape) == 2
 # TODO: Need to add options to qconfig to avoid the calibration.
 # TODO: Add calibration for the sparsity
 assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
 activation_post_process = mod.activation_post_process
-weight_post_process = mod.qconfig.weight()
+weight_post_process = mod.qconfig.weight() # type: ignore[operator, union-attr]

 # Assumption is that the weight is already sparsified by the
 # `sparsifier.convert`
 weight = mod.weight

 weight_post_process(weight)
 dtype = weight_post_process.dtype
-act_scale, act_zp = activation_post_process.calculate_qparams()
+act_scale, act_zp = activation_post_process.calculate_qparams() # type: ignore[operator, union-attr]
 assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
 w_sc, w_zp = weight_post_process.calculate_qparams()
 if isinstance(w_zp, torch.Tensor):
@@ -193,15 +193,15 @@ def from_float(cls, mod):
 assert w_zp == 0, 'Weight zero point must map to 0'
 qweight = _quantize_weight(weight.float(), weight_post_process)

-row_block_size = mod.sparse_params['sparse_block_shape'][0]
-col_block_size = mod.sparse_params['sparse_block_shape'][1]
+row_block_size = mod.sparse_params['sparse_block_shape'][0] # type: ignore[index]
+col_block_size = mod.sparse_params['sparse_block_shape'][1] # type: ignore[index]
 qlinear = cls(mod.in_features,
 mod.out_features,
 row_block_size,
 col_block_size,
 dtype=dtype)
 qlinear.set_weight_bias(qweight, mod.bias,
-row_block_size, col_block_size)
+row_block_size, col_block_size) # type: ignore[arg-type]
 qlinear.scale = float(act_scale)
 qlinear.zero_point = int(act_zp)
 return qlinear
```

torch/autograd/functional.py (+1 -1)

```
@@ -686,7 +686,7 @@ def vjp(grad_output):
 raise RuntimeError(msg)
 jac_i_el.append(torch.zeros_like(inp_el))

-jacobian += (tuple(torch.stack(jac_i_el, dim=0).view(out.size()
+jacobian += (tuple(torch.stack(jac_i_el, dim=0).view(out.size() # type: ignore[operator]
 + inputs[el_idx].size()) for (el_idx, jac_i_el) in enumerate(jac_i)), )

 jacobian = _grad_postprocess(jacobian, create_graph)
```

torch/distributed/distributed_c10d.py (+1 -1)

```
@@ -1880,7 +1880,7 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
 if my_rank == src:
 object_tensor = torch.cat(tensor_list)
 else:
-object_tensor = torch.empty(
+object_tensor = torch.empty( # type: ignore[call-overload]
 torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type]
 dtype=torch.uint8,
 )
```

torch/distributed/elastic/multiprocessing/api.py (+1 -1)

```
@@ -47,7 +47,7 @@ def __init__(self, msg: str, sigval: signal.Signals) -> None:
 self.sigval = sigval


-def _terminate_process_handler(signum: int, frame: FrameType) -> None:
+def _terminate_process_handler(signum: int, frame: Optional[FrameType]) -> None:
 """Termination handler that raises exceptions on the main process.

 When the process receives death signal(SIGTERM, SIGINT), this termination handler will
```
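In typeshed, signal handlers receive an `Optional[FrameType]` (the frame can be `None`), so the handler signature has to say so to match what `signal.signal` expects. A minimal standalone sketch of a handler with that signature (not this module's actual handler logic):

```python
import signal
from types import FrameType
from typing import Optional


def _terminate_handler(signum: int, frame: Optional[FrameType]) -> None:
    # Raise instead of exiting so the main thread can unwind and clean up.
    raise KeyboardInterrupt(f"received signal {signum}")


# typeshed types the callback as Callable[[int, Optional[FrameType]], Any].
signal.signal(signal.SIGTERM, _terminate_handler)
```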

torch/fx/experimental/unification/variable.py (+1 -1)

```
@@ -19,7 +19,7 @@ def __new__(cls, *token):
 token = token[0]

 obj = object.__new__(cls)
-obj.token = token
+obj.token = token # type: ignore[attr-defined]
 return obj

 def __str__(self):
```
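The `attr-defined` ignore is needed because the attribute is set on an instance created via `object.__new__` while the class declares no such attribute. Declaring it on the class is the alternative; a minimal sketch with a simplified stand-in for `Var`:

```python
class Var:
    token: object  # declaring the attribute up front would make the ignore unnecessary

    def __new__(cls, token: object) -> "Var":
        obj = object.__new__(cls)
        obj.token = token
        return obj


print(Var("x").token)
```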

torch/fx/graph.py (+1 -1)

```
@@ -151,7 +151,7 @@ def create_name(self, candidate: str, obj: Optional[Any]) -> str:
 num += 1
 candidate = f'{base}_{num}'

-self._used_names.setdefault(candidate)
+self._used_names.setdefault(candidate, 0)
 if obj is None:
 self._unassociated_names.add(candidate)
 else:
```
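The explicit default matters because `dict.setdefault(key)` with no second argument inserts `None`, which newer mypy rejects when the dictionary's value type is `int` (assuming `_used_names` is annotated as `Dict[str, int]`). A tiny illustration:

```python
from typing import Dict

used_names: Dict[str, int] = {}

# setdefault(key) alone would store None, which does not fit the int value type;
# passing an explicit default keeps the stored value an int.
used_names.setdefault("relu", 0)
used_names["relu"] += 1
print(used_names)
```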

torch/fx/graph_module.py (+6 -3)

```
@@ -240,7 +240,10 @@ def __init__(self, cls, cls_call):
 def _generate_error_message(frame_summary: traceback.FrameSummary) -> str:
 # auxiliary variables (for readability)
 err_lineno = frame_summary.lineno
-err_line_len = len(frame_summary.line)
+assert err_lineno is not None
+line = frame_summary.line
+assert line is not None
+err_line_len = len(line)
 all_src_lines = linecache.getlines(frame_summary.filename)

 # constituent substrings of the error message
@@ -260,7 +263,7 @@ def __call__(self, obj, *args, **kwargs):
 if self.cls_call is not None:
 return self.cls_call(obj, *args, **kwargs)
 else:
-return super(self.cls, obj).__call__(*args, **kwargs)
+return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
 except Exception as e:
 assert e.__traceback__
 topmost_framesummary: traceback.FrameSummary = \
@@ -638,7 +641,7 @@ def recompile(self) -> PythonCode:
 cls_call = cls.__call__ if "__call__" in vars(cls) else None

 if '_wrapped_call' not in vars(cls):
-cls._wrapped_call = _WrappedCall(cls, cls_call)
+cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]

 def call_wrapped(self, *args, **kwargs):
 return self._wrapped_call(self, *args, **kwargs)
```
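Current typeshed marks both `FrameSummary.lineno` and `FrameSummary.line` as Optional, so each needs narrowing before being used in arithmetic or `len()`. A standalone sketch of that narrowing with a synthetic frame summary (not the real traceback handling):

```python
import traceback

fs = traceback.FrameSummary("example.py", 3, "fn", line="raise ValueError('boom')")

assert fs.lineno is not None  # Optional[int] in the stubs
line = fs.line
assert line is not None       # Optional[str] in the stubs

print(fs.lineno, len(line))
```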

torch/hub.py (+1)

```
@@ -72,6 +72,7 @@ def _import_module(name, path):
 import importlib.util
 from importlib.abc import Loader
 spec = importlib.util.spec_from_file_location(name, path)
+assert spec is not None
 module = importlib.util.module_from_spec(spec)
 assert isinstance(spec.loader, Loader)
 spec.loader.exec_module(module)
```

torch/jit/_script.py (+5 -5)

```
@@ -13,7 +13,7 @@
 import copy
 import pickle
 import warnings
-from typing import Any, Dict, List, Tuple, Union, Callable
+from typing import Any, Dict, List, Set, Tuple, Union, Callable


 import torch
@@ -249,7 +249,7 @@ def __init__(cls, name, bases, attrs): # noqa: B902
 for base in reversed(bases):
 for k, v in getattr(base, "_methods", {}).items():
 cls._methods[k] = v
-base_constants = getattr(base, "_constants_set", set())
+base_constants: Set = getattr(base, "_constants_set", set())
 cls._constants_set = cls._constants_set.union(base_constants)

 # find all the script methods of the current class
@@ -417,7 +417,7 @@ def __getattr__(self, attr):
 return super(RecursiveScriptClass, self).__getattr__(attr) # type: ignore[misc]

 if attr in self._props:
-return self._props[attr].fget()
+return self._props[attr].fget() # type: ignore[call-arg, misc]

 return getattr(self._c, attr)

@@ -426,7 +426,7 @@ def __setattr__(self, attr, value):
 return super(RecursiveScriptClass, self).__setattr__(attr, value)

 if attr in self._props:
-return self._props[attr].fset(value)
+return self._props[attr].fset(value) # type: ignore[call-arg, misc]

 setattr(self._c, attr, value)

@@ -1306,7 +1306,7 @@ def forward(self, a) -> MyModule:
 qualified_name = _qualified_name(obj)
 # this is a decorated fn, and we need to the underlying fn and its rcb
 if hasattr(obj, "__script_if_tracing_wrapper"):
-obj = obj.__original_fn
+obj = obj.__original_fn # type: ignore[union-attr]
 _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)

 # some functions are explicitly marked as not supported in script mode
```

torch/jit/frontend.py (+2 -3)

```
@@ -537,9 +537,8 @@ def build_AnnAssign(ctx, stmt):
 raise UnsupportedNodeError(ctx, stmt, reason='without assigned value')

 # Disallow type annotations on instance attributes outside of __init__
-if type(stmt.target) == ast.Attribute and\
-stmt.target.value.id == "self" and\
-ctx.funcname != "__init__":
+if type(stmt.target) == ast.Attribute and \
+stmt.target.value.id == "self" and ctx.funcname != "__init__": # type: ignore[attr-defined]
 start = stmt.col_offset
 end = start + len(f"self.{stmt.target.attr}")
 if hasattr(stmt.annotation, 'id'):
```

torch/nn/modules/module.py (+1 -1)

```
@@ -1943,6 +1943,6 @@ def _replicate_for_data_parallel(self):
 replica._parameters = OrderedDict()
 replica._buffers = replica._buffers.copy()
 replica._modules = replica._modules.copy()
-replica._is_replica = True
+replica._is_replica = True # type: ignore[assignment]

 return replica
```

torch/nn/utils/prune.py (+1 -1)

```
@@ -307,7 +307,7 @@ def add_pruning_method(self, method):
 + " Found '{}'".format(method._tensor_name)
 )
 # if all checks passed, add to _pruning_methods tuple
-self._pruning_methods += (method,)
+self._pruning_methods += (method,) # type: ignore[operator]

 def __len__(self):
 return len(self._pruning_methods)
```

torch/package/_package_pickler.py (+13 -13)

```
@@ -24,15 +24,15 @@ def __init__(self, importer: Importer, *args, **kwargs):
 # is imported, then the offending library removes its dispatch entries,
 # leaving PackagePickler with a stale dispatch table that may cause
 # unwanted behavior.
-self.dispatch = _Pickler.dispatch.copy()
-self.dispatch[FunctionType] = PackagePickler.save_global
+self.dispatch = _Pickler.dispatch.copy() # type: ignore[misc]
+self.dispatch[FunctionType] = PackagePickler.save_global # type: ignore[assignment]

 def save_global(self, obj, name=None):
 # unfortunately the pickler code is factored in a way that
 # forces us to copy/paste this function. The only change is marked
 # CHANGED below.
-write = self.write
-memo = self.memo
+write = self.write # type: ignore[attr-defined]
+memo = self.memo # type: ignore[attr-defined]

 # CHANGED: import module from module environment instead of __import__
 try:
@@ -44,7 +44,7 @@ def save_global(self, obj, name=None):
 _, parent = _getattribute(module, name)
 # END CHANGED

-if self.proto >= 2:
+if self.proto >= 2: # type: ignore[attr-defined]
 code = _extension_registry.get((module_name, name))
 if code:
 assert code > 0
@@ -59,13 +59,13 @@ def save_global(self, obj, name=None):
 if parent is module:
 name = lastname
 # Non-ASCII identifiers are supported only with protocols >= 3.
-if self.proto >= 4:
-self.save(module_name)
-self.save(name)
+if self.proto >= 4: # type: ignore[attr-defined]
+self.save(module_name) # type: ignore[attr-defined]
+self.save(name) # type: ignore[attr-defined]
 write(STACK_GLOBAL)
 elif parent is not module:
-self.save_reduce(getattr, (parent, lastname))
-elif self.proto >= 3:
+self.save_reduce(getattr, (parent, lastname)) # type: ignore[attr-defined]
+elif self.proto >= 3: # type: ignore[attr-defined]
 write(
 GLOBAL
 + bytes(module_name, "utf-8")
@@ -74,7 +74,7 @@ def save_global(self, obj, name=None):
 + b"\n"
 )
 else:
-if self.fix_imports:
+if self.fix_imports: # type: ignore[attr-defined]
 r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING
 r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING
 if (module_name, name) in r_name_mapping:
@@ -92,10 +92,10 @@ def save_global(self, obj, name=None):
 except UnicodeEncodeError:
 raise PicklingError(
 "can't pickle global identifier '%s.%s' using "
-"pickle protocol %i" % (module, name, self.proto)
+"pickle protocol %i" % (module, name, self.proto) # type: ignore[attr-defined]
 ) from None

-self.memoize(obj)
+self.memoize(obj) # type: ignore[attr-defined]


 def create_pickler(data_buf, importer, protocol=4):
```

torch/package/_package_unpickler.py (+1 -1)

```
@@ -17,7 +17,7 @@ def __init__(self, importer: Importer, *args, **kwargs):

 def find_class(self, module, name):
 # Subclasses may override this.
-if self.proto < 3 and self.fix_imports:
+if self.proto < 3 and self.fix_imports: # type: ignore[attr-defined]
 if (module, name) in _compat_pickle.NAME_MAPPING:
 module, name = _compat_pickle.NAME_MAPPING[(module, name)]
 elif module in _compat_pickle.IMPORT_MAPPING:
```

torch/package/package_importer.py (+1 -1)

```
@@ -237,7 +237,7 @@ def persistent_load(saved_id):
 # Load the data (which may in turn use `persistent_load` to load tensors)
 data_file = io.BytesIO(self.zip_reader.get_record(pickle_file))
 unpickler = self.Unpickler(data_file)
-unpickler.persistent_load = persistent_load
+unpickler.persistent_load = persistent_load # type: ignore[assignment]

 @contextmanager
 def set_deserialization_context():
```
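The `assignment` ignore here (and in torch/_deploy.py above) is needed because typeshed declares `Unpickler.persistent_load` as a method, and overwriting a method with a plain function on an instance is flagged by mypy even though the pickle module accepts it at runtime. A minimal standalone illustration (toy payload, not torch.package's loading logic):

```python
import io
import pickle


def persistent_load(saved_id: object) -> object:
    # A real implementation would resolve saved_id; here we just refuse it.
    raise pickle.UnpicklingError(f"unsupported persistent id: {saved_id!r}")


unpickler = pickle.Unpickler(io.BytesIO(pickle.dumps([1, 2, 3])))
unpickler.persistent_load = persistent_load  # type: ignore[assignment]
print(unpickler.load())
```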
