
Commit fbd7e74

aakhundov authored and pytorchmergebot committed
[inductor] Enable mypy checking in lowering.py (pytorch#105317)
Summary: As suggested in pytorch#105230, mypy checking is enabled in `torch/_inductor/lowering.py`. 23 errors are fixed; 6 are silenced with `# type: ignore[attr-defined]`.

Test Plan:

Before the fix:

```
$ mypy torch/_inductor/lowering.py
torch/_inductor/lowering.py:139:16: error: "Symbol" has no attribute "is_integer" [attr-defined]
torch/_inductor/lowering.py:263:20: error: Incompatible types in assignment (expression has type "Union[List[Any], Tuple[Any, ...]]", variable has type "List[Any]") [assignment]
torch/_inductor/lowering.py:427:49: error: "IRNode" has no attribute "get_size" [attr-defined]
torch/_inductor/lowering.py:439:37: error: "IRNode" has no attribute "get_dtype" [attr-defined]
torch/_inductor/lowering.py:456:34: error: "IRNode" has no attribute "get_device" [attr-defined]
torch/_inductor/lowering.py:645:44: error: Need type annotation for "b" [var-annotated]
torch/_inductor/lowering.py:1321:12: error: "FakeTensor" has no attribute "is_cpu" [attr-defined]
torch/_inductor/lowering.py:1542:24: error: Argument 3 to "FixedLayout" has incompatible type "List[int]"; expected "List[Expr]" [arg-type]
torch/_inductor/lowering.py:1542:81: error: Argument "offset" to "FixedLayout" has incompatible type "int"; expected "Expr" [arg-type]
torch/_inductor/lowering.py:1571:24: error: Argument 3 to "FixedLayout" has incompatible type "List[int]"; expected "List[Expr]" [arg-type]
torch/_inductor/lowering.py:1571:81: error: Argument "offset" to "FixedLayout" has incompatible type "int"; expected "Expr" [arg-type]
torch/_inductor/lowering.py:1654:12: error: Incompatible types in assignment (expression has type "List[Any]", variable has type "Tuple[Any, ...]") [assignment]
torch/_inductor/lowering.py:2009:9: error: Need type annotation for "ranges" (hint: "ranges: List[<type>] = ...") [var-annotated]
torch/_inductor/lowering.py:2151:16: error: Incompatible types in assignment (expression has type "List[Any]", variable has type "Tuple[Any, ...]") [assignment]
torch/_inductor/lowering.py:2198:43: error: Item "type" of "Union[List[Any], type]" has no attribute "__iter__" (not iterable) [union-attr]
torch/_inductor/lowering.py:2229:36: error: Argument 1 to "len" has incompatible type "Union[List[Any], type]"; expected "Sized" [arg-type]
torch/_inductor/lowering.py:2231:38: error: Item "type" of "Union[List[Any], type]" has no attribute "__iter__" (not iterable) [union-attr]
torch/_inductor/lowering.py:2233:35: error: Item "type" of "Union[List[Any], type]" has no attribute "__iter__" (not iterable) [union-attr]
torch/_inductor/lowering.py:2569:54: error: Incompatible default for argument "reduce" (default has type "None", argument has type "str") [assignment]
torch/_inductor/lowering.py:2569:54: note: PEP 484 prohibits implicit Optional. Accordingly, mypy has changed its default to no_implicit_optional=True
torch/_inductor/lowering.py:2569:54: note: Use https://github.com/hauntsaninja/no_implicit_optional to automatically upgrade your codebase
torch/_inductor/lowering.py:2586:59: error: Incompatible default for argument "reduce" (default has type "None", argument has type "str") [assignment]
torch/_inductor/lowering.py:2586:59: note: PEP 484 prohibits implicit Optional. Accordingly, mypy has changed its default to no_implicit_optional=True
torch/_inductor/lowering.py:2586:59: note: Use https://github.com/hauntsaninja/no_implicit_optional to automatically upgrade your codebase
torch/_inductor/lowering.py:2720:65: error: Incompatible default for argument "scales_x" (default has type "None", argument has type "Tuple[float]") [assignment]
torch/_inductor/lowering.py:2720:65: note: PEP 484 prohibits implicit Optional. Accordingly, mypy has changed its default to no_implicit_optional=True
torch/_inductor/lowering.py:2720:65: note: Use https://github.com/hauntsaninja/no_implicit_optional to automatically upgrade your codebase
torch/_inductor/lowering.py:2735:5: error: Name "scale" already defined on line 2731 [no-redef]
torch/_inductor/lowering.py:2758:47: error: Argument 3 to "upsample_nearestnd" has incompatible type "Tuple[Optional[float]]"; expected "Tuple[float]" [arg-type]
torch/_inductor/lowering.py:2765:47: error: Argument 3 to "upsample_nearestnd" has incompatible type "Tuple[Optional[float], Optional[float]]"; expected "Tuple[float]" [arg-type]
torch/_inductor/lowering.py:2776:47: error: Argument 3 to "upsample_nearestnd" has incompatible type "Tuple[Optional[float], Optional[float], Optional[float]]"; expected "Tuple[float]" [arg-type]
torch/_inductor/lowering.py:2949:13: error: No binding for nonlocal "grad" found [misc]
torch/_inductor/lowering.py:3063:49: error: Argument 2 to "range_mask_low" has incompatible type "int"; expected "Expr" [arg-type]
torch/_inductor/lowering.py:3271:48: error: "IRNode" has no attribute "data" [attr-defined]
torch/_inductor/lowering.py:3272:16: error: "IRNode" has no attribute "data" [attr-defined]
Found 29 errors in 1 file (checked 1 source file)
```

After the fix:

```
$ mypy torch/_inductor/lowering.py
Success: no issues found in 1 source file
```

Reviewers: @eellison

Subscribers:

Tasks:

Tags:

Pull Request resolved: pytorch#105317
Approved by: https://github.com/eellison
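Several of the errors above are instances of the same pattern: with `no_implicit_optional=True`, a `None` default no longer implies an `Optional` type under PEP 484. A minimal sketch of the pattern and its fix, modeled on the `scatter_` signature from this diff (the body here is illustrative only):

```python
from typing import Optional

# Before: mypy rejects this under no_implicit_optional=True, because the
# default value None is not a valid "str".
#   def scatter_(self, dim, index, src, *, reduce: str = None): ...

# After: the Optional annotation makes the None default explicit.
def scatter_(self, dim: int, index, src, *, reduce: Optional[str] = None):
    assert reduce in {None, "add", "multiply"}
```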
1 parent 88f1885 commit fbd7e74

File tree

4 files changed (+59, -45 lines)


.lintrunner.toml (+1)

```diff
@@ -182,6 +182,7 @@ include_patterns = [
     'torch/_inductor/graph.py',
     'torch/_inductor/codegen/wrapper.py',
     'torch/_inductor/cudagraph_trees.py',
+    'torch/_inductor/lowering.py',
     'torch/_C/_dynamo/**/*.py',
     'test/test_utils.py', # used to by in MYPY but after importing op_db it took 10+ minutes
 ]
```

tools/pyi/gen_pyi.py (+1)

```diff
@@ -1035,6 +1035,7 @@ def replace_special_case(hint: str) -> str:
             "def is_contiguous(self, memory_format=torch.contiguous_format) -> _bool: ..."
         ],
         "_is_view": ["def _is_view(self) -> _bool: ..."],
+        "is_cpu": ["is_cpu: _bool"],
         "is_cuda": ["is_cuda: _bool"],
         "is_leaf": ["is_leaf: _bool"],
         "is_nested": ["is_nested: _bool"],
```

torch/_inductor/ir.py (+4, -4)

```diff
@@ -1818,9 +1818,9 @@ def __init__(
         self,
         device: torch.device,
         dtype: torch.dtype,
-        size: List[Expr],
-        stride: List[Expr] = None,
-        offset: Expr = Integer(0),
+        size: Union[List[Expr], List[int]],
+        stride: Optional[Union[List[Expr], List[int]]] = None,
+        offset: Union[Expr, int] = Integer(0),
     ):
         if stride is None:
             stride = FlexibleLayout.contiguous_strides(size)
@@ -3116,7 +3116,7 @@ def __init__(
         index,
         src,
         *,
-        reduce: str = None,
+        reduce: Optional[str] = None,
         include_self: bool = True,
     ):
         assert fn in {"aten.scatter_", "aten.scatter_reduce_"}
```
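The widened `FixedLayout` signature resolves the `arg-type` errors by accepting what call sites actually pass: plain Python ints as well as sympy expressions. A minimal sketch of the pattern, assuming a hypothetical `fixed_layout_sketch` stand-in rather than the real constructor:

```python
from typing import List, Optional, Union

import sympy
from sympy import Expr

# Callers mix ints and sympy Exprs, so the annotations accept both instead
# of forcing every call site to wrap its ints in sympy.Integer.
def fixed_layout_sketch(
    size: Union[List[Expr], List[int]],
    stride: Optional[Union[List[Expr], List[int]]] = None,
    offset: Union[Expr, int] = sympy.Integer(0),
) -> List[Expr]:
    # sympy.sympify normalizes ints and Exprs alike, so the body can treat
    # both cases uniformly.
    return [sympy.sympify(s) for s in size]
```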

torch/_inductor/lowering.py (+53, -41)

```diff
@@ -5,7 +5,7 @@
 import warnings
 from collections import defaultdict
 from collections.abc import Iterable
-from typing import List, Optional, Tuple
+from typing import Any, List, Optional, Tuple, Union
 
 import sympy
 
@@ -130,7 +130,7 @@ def is_integer_type(x):
     if isinstance(x, TensorBox):
        return is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype())
     elif isinstance(x, sympy.Symbol):
-        return x.is_integer is True
+        return x.is_integer is True  # type: ignore[attr-defined]
     else:
         return isinstance(x, int)
```
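The `is_integer` ignore above is the typical way to handle sympy's assumptions machinery: the attribute is synthesized at runtime, so mypy's static analysis cannot see it. A small self-contained illustration:

```python
import sympy

n = sympy.Symbol("n", integer=True)

# is_integer comes from sympy's runtime assumptions system, invisible to
# mypy; the targeted ignore silences exactly this error code and no others.
assert n.is_integer is True  # type: ignore[attr-defined]
```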

```diff
@@ -249,7 +249,7 @@ def _register_lowering(
 
     @functools.wraps(decomp_fn)
     def wrapped(*args, **kwargs):
-        args = list(args)
+        args: Union[List[Any], Tuple[Any, ...]] = list(args)
         unpacked = False
         # TODO maybe we need to use pytrees here
         if len(args) == 1 and isinstance(args[0], (list, tuple)):
@@ -418,7 +418,7 @@ def inner(*inputs: List[List[TensorBox]], alpha=1):
     def is_dynamic(*args):
         return any(
             isinstance(t, TensorBox)
-            and any(x.free_symbols for x in t.data.get_size())
+            and any(x.free_symbols for x in t.data.get_size())  # type: ignore[attr-defined]
             for t in args
         )
 
@@ -430,7 +430,7 @@ def has_type_promotion(*args):
     for t in args:
         if isinstance(t, TensorBox):
             if dtype is None:
-                dtype = t.data.get_dtype()
+                dtype = t.data.get_dtype()  # type: ignore[attr-defined]
             elif dtype != t.data.get_dtype():
                 return True
     return False
@@ -447,7 +447,7 @@ def group_args(arg_pairs):
     device = None
     for t in args:
         if isinstance(t, TensorBox):
-            device = t.data.get_device()
+            device = t.data.get_device()  # type: ignore[attr-defined]
             break
     assert (
         device is not None
@@ -630,8 +630,8 @@ def fn(*args):
 def broadcast_tensors(*inputs):
     if len(inputs) == 1 and isinstance(inputs[0], (list, tuple)):
         return broadcast_tensors(*inputs[0])
-    target = functools.reduce(
-        broadcast_symbolic_shapes, [x.get_size() for x in inputs], ()
+    target: List[sympy.Expr] = functools.reduce(
+        broadcast_symbolic_shapes, [x.get_size() for x in inputs], []
     )
     outputs = []
     for x in inputs:
```
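The `broadcast_tensors` change is a `functools.reduce` typing fix: mypy types the result of `reduce` from the initial value, so starting from `()` produced a tuple where a `List[sympy.Expr]` was declared. A minimal sketch, with a hypothetical `merge` standing in for `broadcast_symbolic_shapes`:

```python
import functools
from typing import List

import sympy

# Stand-in for broadcast_symbolic_shapes: just keep the longer shape.
def merge(a: List[sympy.Expr], b: List[sympy.Expr]) -> List[sympy.Expr]:
    return a if len(a) >= len(b) else b

shapes = [[sympy.Integer(2)], [sympy.Integer(2), sympy.Integer(3)]]

# Starting from [] (not ()) makes reduce's inferred result match the
# declared list type.
target: List[sympy.Expr] = functools.reduce(merge, shapes, [])
```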
```diff
@@ -1645,7 +1645,9 @@ def apply_constraint(arg, fx_arg):
             return ir.ExternKernel.require_stride_order(arg, stride_order)
         return arg
 
-    args = [apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)]
+    args = tuple(
+        apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)
+    )
     kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()}
     return args, kwargs
 
@@ -1999,21 +2001,21 @@ def tensor(data, *, dtype=None, device=None, layout=None, pin_memory=False):
     else:
         dtype = dtype or torch.get_default_dtype()
 
+    ranges: List[sympy.Expr] = []
+
     if isinstance(data, sympy.Expr):
-        ranges = []
 
         def inner_fn(index):
             return ops.index_expr(data, dtype)
 
     elif isinstance(data, (float, int)):
-        ranges = []
 
         def inner_fn(index):
             return ops.constant(data, dtype)
 
     elif len(data) == 0 or isinstance(data[0], (float, int)) and len(data) <= 8:
         # inline small tensors
-        ranges = [sympy.Integer(len(data))]
+        ranges.append(sympy.Integer(len(data)))
 
         def inner_fn(index):
             def binary_search(start, end):
```
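Hoisting `ranges` above the branches resolves mypy's `var-annotated` error: an empty literal inside each branch gives the checker nothing to infer, while one annotated initialization covers every path. A minimal sketch of the pattern, assuming a hypothetical `ranges_sketch` helper:

```python
from typing import List

import sympy

def ranges_sketch(data) -> List[sympy.Expr]:
    # One annotated initialization gives mypy the element type up front;
    # a bare "ranges = []" repeated inside each branch is what triggered
    # the var-annotated error.
    ranges: List[sympy.Expr] = []
    if isinstance(data, (list, tuple)):
        ranges.append(sympy.Integer(len(data)))
    return ranges
```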
```diff
@@ -2142,7 +2144,7 @@ def empty(
     assert memory_format in (None, torch.contiguous_format)
     device = decode_device(device)
     if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)):
-        size = list(size[0])
+        size = tuple(size[0])
     return empty_strided(
         size, None, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
     )
@@ -2184,7 +2186,7 @@ def new_constant(fill_value):
     def _new_constant(
         x, size, *, dtype=None, layout=None, device=None, pin_memory=None
     ):
-        assert isinstance(size, (list, type))
+        assert isinstance(size, (list, tuple))
         assert not pin_memory
         assert layout in (None, torch.strided)
         dtype = decode_dtype(dtype) or x.get_dtype()
@@ -2210,8 +2212,8 @@ def new_empty(x, size, *, dtype=None, layout=None, device=None, pin_memory=None)
 def empty_strided(
     size, stride, *, dtype=None, layout=None, device=None, pin_memory=None
 ):
-    assert isinstance(size, (list, type))
-    assert isinstance(stride, (list, type, type(None)))
+    assert isinstance(size, (list, tuple))
+    assert isinstance(stride, (list, tuple, type(None)))
     assert not pin_memory
     assert layout in (None, torch.strided)
     dtype = decode_dtype(dtype) or torch.get_default_dtype()
```
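The `isinstance` fixes above correct a genuine typo that mypy surfaced as the `union-attr` and `Sized` errors: the old asserts named the builtin `type` where `tuple` was intended, and a bare `type` is neither iterable nor sized. A quick self-contained demonstration:

```python
# The old check used "type" where "tuple" was intended:
assert isinstance([2, 3], (list, type))      # passes, but only via "list"
assert not isinstance((2, 3), (list, type))  # a tuple value is not a type
assert isinstance((2, 3), (list, tuple))     # the corrected check
```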
```diff
@@ -2560,7 +2562,14 @@ def scatter(x, dim: int, index, src, **kwargs):
 
 
 def scatter_fallback(
-    fn, self, dim: int, index, src, *, reduce: str = None, include_self: bool = True
+    fn,
+    self,
+    dim: int,
+    index,
+    src,
+    *,
+    reduce: Optional[str] = None,
+    include_self: bool = True,
 ):
     reduce_ty = "add" if fn == "aten.scatter_" else "sum"
     if (
@@ -2577,7 +2586,7 @@ def scatter_fallback(
 
 
 @register_lowering(aten.scatter_, type_promotion_kind=None)
-def scatter_(self, dim: int, index, src, *, reduce: str = None):
+def scatter_(self, dim: int, index, src, *, reduce: Optional[str] = None):
     assert reduce in {None, "add", "multiply"}
 
     fallback_result = scatter_fallback(
@@ -2711,7 +2720,9 @@ def backend_reduce_str(reduce):
     return self
 
 
-def upsample_nearestnd(x, output_size, scales_x: Tuple[float] = None, n: int = 2):
+def upsample_nearestnd(
+    x, output_size, scales_x: Tuple[Optional[float], ...], n: int = 2
+):
     x.realize_hint()  # elements are reused
     x_loader = x.make_loader()
     i_sizes = x.get_size()[-n:]
@@ -2726,7 +2737,7 @@ def upsample_nearestnd(x, output_size, scales_x: Tuple[float] = None, n: int = 2
         if scale:
             scales[i] = scale
 
-    def scale(x, scale, size):
+    def scale_fn(x, scale, size):
         x = ops.index_expr(x, torch.float32)
         x = ops.mul(x, ops.constant(scale, torch.float32))
         x = ops.to_dtype(x, torch.int32)
@@ -2736,7 +2747,7 @@ def fn(idx):
         x = idx[-n:]
         b = idx[:-n]
         return x_loader(
-            [*b, *[scale(i, s, size) for i, s, size in zip(x, scales, i_sizes)]]
+            [*b, *[scale_fn(i, s, size) for i, s, size in zip(x, scales, i_sizes)]]
         )
 
     return Pointwise.create(
```
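The `scale` to `scale_fn` rename fixes the `no-redef` error: the name was bound first as a loop variable and then as an inner function in the same scope, and mypy flags the second binding. A minimal sketch of the shadowing pattern:

```python
# One scope binds the same name to a float and then to a function, so mypy
# reports the second binding.
scale = 2.0  # first binding: a float

# def scale(x):  # mypy: Name "scale" already defined  [no-redef]
#     return x * 2.0

def scale_fn(x: float) -> float:  # renamed, mirroring the diff
    return x * scale

assert scale_fn(3.0) == 6.0
```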
```diff
@@ -2939,23 +2950,6 @@ def index_range_condition(index_range):
         ub = ops.index_expr(ub, torch.int64)
         return ops.and_(ops.ge(i, lb), ops.le(i, ub))
 
-    def accumulate(out_x, out_y, index_range1, index_range2=None):
-        nonlocal grad
-
-        # If the upper bound is less than the lower bound, we can get rid of one accumulation.
-        # This happens when the padding size is zero.
-        upper_less_than_lower1 = index_range1[2] < index_range1[1]
-        if isinstance(upper_less_than_lower1, bool) and upper_less_than_lower1:
-            return
-        cond = index_range_condition(index_range1)
-        if index_range2 is not None:
-            upper_less_than_lower2 = index_range2[2] < index_range2[1]
-            if isinstance(upper_less_than_lower2, bool) and upper_less_than_lower2:
-                return
-            cond = ops.and_(cond, index_range_condition(index_range2))
-        g = ops.masked(cond, lambda: load_from_output(out_x, out_y), 0.0)
-        grad = ops.add(grad, g)
-
     # Areas after reflection:
     #
     #   top-left    |   top     |   top-right
@@ -2978,6 +2972,24 @@ def accumulate(out_x, out_y, index_range1, index_range2=None):
         index_range_condition(range_cx), index_range_condition(range_cy)
     )
     grad = ops.masked(cond, lambda: load_from_output(center_x, center_y), 0.0)
+
+    def accumulate(out_x, out_y, index_range1, index_range2=None):
+        nonlocal grad
+
+        # If the upper bound is less than the lower bound, we can get rid of one accumulation.
+        # This happens when the padding size is zero.
+        upper_less_than_lower1 = index_range1[2] < index_range1[1]
+        if isinstance(upper_less_than_lower1, bool) and upper_less_than_lower1:
+            return
+        cond = index_range_condition(index_range1)
+        if index_range2 is not None:
+            upper_less_than_lower2 = index_range2[2] < index_range2[1]
+            if isinstance(upper_less_than_lower2, bool) and upper_less_than_lower2:
+                return
+            cond = ops.and_(cond, index_range_condition(index_range2))
+        g = ops.masked(cond, lambda: load_from_output(out_x, out_y), 0.0)
+        grad = ops.add(grad, g)
+
     accumulate(center_x, left_reflect_y, range_cx, (y, 1, left))
     accumulate(center_x, right_reflect_y, range_cx, (y, w - right, w - 1))
     accumulate(top_reflect_x, center_y, (x, 1, top), range_cy)
```
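Moving `accumulate` below the first assignment of `grad` fixes the `No binding for nonlocal "grad" found` error. CPython accepts a `nonlocal` declaration as long as the enclosing scope binds the name anywhere, but mypy wants the binding to appear before the inner function that declares it, so the relocation satisfies the checker without changing behavior. A minimal sketch of the ordering mypy accepts:

```python
def outer_sketch() -> float:
    # mypy flags the reverse ordering: if "accumulate" were defined up here,
    # before "grad" is bound in this scope, it would report
    # error: No binding for nonlocal "grad" found  [misc]
    grad = 0.0  # bind the name first ...

    def accumulate() -> None:  # ... then define the closure that mutates it
        nonlocal grad
        grad += 1.0

    accumulate()
    return grad

assert outer_sketch() == 1.0
```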
```diff
@@ -3076,7 +3088,7 @@ def offset_fn(index):
     )
 
 
-def range_mask_low(i: sympy.Expr, low: sympy.Expr):
+def range_mask_low(i: sympy.Expr, low: Union[sympy.Expr, int]):
     return ops.ge(
         ops.index_expr(i, torch.int64),
         ops.index_expr(sympy.Integer(low), torch.int64),
@@ -3262,8 +3274,8 @@ def max_pool2d_with_indices_backward(
     # some classes don't have `get_stride`
     # TODO will need a better way of determining if inputs are channels-last
     gO_stride = None
-    if isinstance(x, TensorBox) and isinstance(x.data.data, Pointwise):
-        data = x.data.data
+    if isinstance(x, TensorBox) and isinstance(x.data.data, Pointwise):  # type: ignore[attr-defined]
+        data = x.data.data  # type: ignore[attr-defined]
         x_buffer = ir.ComputedBuffer(
             name=None,
             layout=ir.FlexibleLayout(
```
