
Commit 86c7652

shunting314 authored and pytorchmergebot committed
[inductor] layout optimization for conv (pytorch#99773)
A convolution kernel with channels-last inputs runs much faster than one with contiguous inputs. This PR leverages that by optimizing tensor layouts so that we provide channels-last inputs to convolutions. Some care needs to be taken to avoid converting tensor layouts back and forth between contiguous and channels last, since those extra copies hurt performance quite a bit. Latest perf numbers [here](https://hud.pytorch.org/benchmark/compilers?startTime=Wed%2C%2024%20May%202023%2023%3A40%3A37%20GMT&stopTime=Wed%2C%2031%20May%202023%2023%3A40%3A37%20GMT&granularity=hour&suite=torchbench&mode=training&dtype=amp&lBranch=shunting-layout-opt-19&lCommit=baa797fc100688dfb044fbcbdebcfd2591710f78&rBranch=main&rCommit=999bae0f54108ffc5b7cf2524a02a83901554b16):

- TB: 1.64x -> 1.69x
- HF: 1.79x -> 1.78x (random noise)
- TIMM: 1.51x -> 1.65x

Right now we disable layout optimization for dynamic shapes since there is a perf loss in that combination. Here is a GH issue to follow up: pytorch#102670

Pull Request resolved: pytorch#99773
Approved by: https://github.com/jansel
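For background (this sketch is not part of the commit): the win described above comes from convolution kernels typically running faster on channels-last (NHWC) inputs on recent GPUs. A minimal, self-contained timing sketch with arbitrary layer sizes; actual speedups depend on the GPU, dtype, and cuDNN version.

import torch
from torch import nn

conv = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False).cuda().half()
x = torch.randn(32, 64, 56, 56, device="cuda", dtype=torch.half)

def bench(fn, iters=50):
    # simple CUDA-event timing with a short warmup
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(10):
        fn()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

t_contig = bench(lambda: conv(x))

# move both the weight and the activation to channels last
conv_cl = conv.to(memory_format=torch.channels_last)
x_cl = x.to(memory_format=torch.channels_last)
t_cl = bench(lambda: conv_cl(x_cl))

print(f"contiguous: {t_contig:.3f} ms/iter, channels_last: {t_cl:.3f} ms/iter")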
1 parent 4da8844 commit 86c7652

File tree

15 files changed: +637, -21 lines


benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv (+1, -1)

@@ -50,4 +50,4 @@ timm_vovnet,pass,8
 tts_angular,pass,10
 vgg16,pass,8
 vision_maskrcnn,fail_accuracy,167
-yolov3,pass,10
+yolov3,pass,11

test/inductor/test_layout_optim.py (+173, new file)

@@ -0,0 +1,173 @@
# Owner(s): ["module: inductor"]
import copy
import os

import torch
from torch import nn
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import same
from torch.testing._internal.common_utils import TEST_WITH_ROCM
from torch.testing._internal.inductor_utils import HAS_CUDA

USE_DDP_WRAPPER = os.environ.get("USE_DDP_WRAPPER", "1") == "1"


class Model2Conv(nn.Module):
    def __init__(self, dim=512, manual_graph_break=False):
        super().__init__()
        self.conv1 = nn.Conv2d(3, dim, kernel_size=3, stride=2, bias=False)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=2, bias=False)
        self.manual_graph_break = manual_graph_break

    def forward(self, x):
        x = self.conv1(x)
        if self.manual_graph_break:
            torch._dynamo.graph_break()
        x = self.conv2(x)
        return x

    def get_example_inputs(self):
        return (torch.rand(2, 3, 16, 16),)


class TestLayoutOptim(TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        import torch.distributed as dist

        port = 10001
        dist.init_process_group(
            backend="nccl", init_method=f"tcp://localhost:{port}", world_size=1, rank=0
        )

    def verify_accuracy(
        self, model_class, use_ddp_wrapper=USE_DDP_WRAPPER, is_train=False
    ):
        # there are 2 potential ways to introduce graph breaks
        # 1. manually
        # 2. using DDP
        # if we are not using DDP to introduce graph breaks, do that manually
        def wrap_mod(m):
            if is_train:

                def f(*inp):
                    x = m(*inp)
                    x.sum().backward()

                    grads = []
                    for name, param in m.named_parameters():
                        grad = param.grad
                        if param.grad is None:
                            grad = torch.zeros_like(param)
                        grads.append(grad)
                    return grads

                return f
            else:
                return m

        manual_graph_break = not use_ddp_wrapper
        mod = model_class(manual_graph_break=manual_graph_break).cuda()
        inp = [t.cuda() for t in mod.get_example_inputs()]
        expected_out = wrap_mod(mod)(*inp)

        fp64_mod = copy.deepcopy(mod).to(torch.float64)
        fp64_inp = [t.to(torch.float64) for t in copy.deepcopy(inp)]
        fp64_out = wrap_mod(fp64_mod)(*fp64_inp)

        if use_ddp_wrapper:
            from torch.nn.parallel import DistributedDataParallel as DDP

            ddp_wrapped_mod = DDP(mod)
            opt_mod = torch.compile(wrap_mod(ddp_wrapped_mod))
        else:
            opt_mod = torch.compile(wrap_mod(mod))
        actual_out = opt_mod(*inp)

        if is_train:
            self.assertTrue(same(expected_out, actual_out, fp64_ref=fp64_out))
        else:
            expected_sum = expected_out.sum()
            actual_sum = actual_out.sum()
            print(f"Expected sum {expected_sum}, actual sum {actual_sum}")
            self.assertTrue(same(expected_out, actual_out, fp64_ref=fp64_out))

    def verify_accuracy_for_infer(self, *args, **kwargs):
        self.verify_accuracy(*args, **kwargs, is_train=False)

    def verify_accuracy_for_train(self, *args, **kwargs):
        self.verify_accuracy(*args, **kwargs, is_train=True)

    def test_2conv_with_graph_break(self):
        """
        Make sure graph break does not cause any accuracy issue.
        """
        self.verify_accuracy_for_infer(Model2Conv)

    def test_3conv_with_graph_break(self):
        class Model(nn.Module):
            def __init__(
                self, dim=512, patch_size=7, kernel_size=7, manual_graph_break=False
            ):
                super().__init__()
                self.seq = nn.Sequential(
                    nn.Conv2d(
                        3, dim, kernel_size=patch_size, stride=patch_size, bias=False
                    ),
                    nn.Conv2d(
                        dim, dim, kernel_size, groups=dim, padding="same", bias=False
                    ),
                )
                self.conv = nn.Conv2d(dim, dim, kernel_size=1, bias=False)
                self.manual_graph_break = manual_graph_break

            def forward(self, x):
                x = self.seq(x)
                if self.manual_graph_break:
                    torch._dynamo.graph_break()
                x = self.conv(x)
                return x

            def get_example_inputs(self):
                return (torch.randn(2, 3, 16, 16),)

        self.verify_accuracy_for_infer(Model)

    def test_keep_output_layout_infer(self):
        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(
                    3, 128, kernel_size=3, padding=1, stride=1, bias=False
                )

            def forward(self, x):
                x = self.conv(x)
                return x

            def get_example_inputs(self):
                return (torch.randn(2, 3, 5, 5),)

        mod = Model().cuda()
        inp = [t.cuda() for t in mod.get_example_inputs()]
        out = mod(*inp)

        opt_mod = torch.compile(mod)
        opt_out = opt_mod(*inp)

        # We should be able to do view on eager output
        out.view(5, -1)

        # We should be able to do view on the output of the optimized module
        # Note that if the output is channels last, the view op will fail.
        opt_out.view(5, -1)

    def test_training_acc(self):
        self.verify_accuracy_for_train(Model2Conv)


if __name__ == "__main__":
    if HAS_CUDA and not TEST_WITH_ROCM:
        run_tests()

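The test_keep_output_layout_infer test above relies on the behavior noted in its comment: view can fail on a channels-last tensor, which is why the compiled module must keep the eager output layout. A standalone illustration (not part of the diff, arbitrary shape):

import torch

x = torch.randn(2, 3, 5, 5)
x_cl = x.to(memory_format=torch.channels_last)

x.view(5, -1)  # works: the default contiguous layout is view-compatible
try:
    x_cl.view(5, -1)  # channels-last strides cannot represent this flattening
except RuntimeError as err:
    print("view failed:", err)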
test/test_fake_tensor.py (+7)

@@ -632,6 +632,13 @@ def test_aten_slice_scatter_multi_device(self):
         self.checkType(r3, "cpu", (4, 4))
         self.checkType(out, "cpu", (4, 4))

+    def test__adaptive_avg_pool2d_backward(self):
+        with FakeTensorMode():
+            grad_out = torch.rand(2, 3, 4, 4)
+            inp = torch.rand(2, 3, 4, 4).to(memory_format=torch.channels_last)
+            grad_in = torch.ops.aten._adaptive_avg_pool2d_backward(grad_out, inp)
+            self.assertTrue(torch._prims_common.suggest_memory_format(grad_in) == torch.channels_last)
+
 class FakeTensorConstHandling(TestCase):
     def assertConst(self, *args):

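The new fake-tensor test checks the propagated layout via torch._prims_common.suggest_memory_format, which infers a memory format from a tensor's strides. A quick illustration on real tensors (not part of the diff):

import torch
from torch._prims_common import suggest_memory_format

a = torch.randn(2, 3, 4, 4)
b = a.to(memory_format=torch.channels_last)
print(suggest_memory_format(a))  # torch.contiguous_format
print(suggest_memory_format(b))  # torch.channels_last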
torch/_functorch/aot_autograd.py (+24, -1)

@@ -2765,6 +2765,14 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig,
         # We are not clearing flat_args here because
         # 1) There is a check in the debug compiler at the end
         # 2) It does not matter as these are fake tensors
+
+        # The compiler needs this field to find the original model outputs
+        # among the AOTAutograd fwd module's outputs, so that it can make sure
+        # optimizations like layout optimization do not change those tensors'
+        # layout.
+        # TODO: once https://github.com/pytorch/pytorch/pull/100652/files#r1212002707 is in,
+        # change this to access fw_metadata from the global tracing context.
+        fw_module.meta["original_output_start_index"] = fw_metadata.num_mutated_inputs
         compiled_fw_func = aot_config.fw_compiler(
             fw_module, adjusted_flat_args
         )

@@ -2981,9 +2989,24 @@ def call_compiled_backward():
         if CompiledFunction.compiled_bw is None:
             assert all(a is not None for a in all_args)
             context = torch._C._DisableAutocast if disable_amp else nullcontext
+
+            placeholder_list = fx_placeholder_vals(bw_module)
+
+            # Saved activations can have different strides than in eager mode if
+            # the compiler does layout optimization. We should restride the
+            # tensors passed in for compiling the backward graph using the
+            # saved tensors' strides.
+            for i in range(len(placeholder_list)):
+                ph_arg = placeholder_list[i]
+                real_arg = all_args[i]
+                if not isinstance(ph_arg, torch.Tensor):
+                    continue
+                if ph_arg.stride() != real_arg.stride():
+                    placeholder_list[i] = ph_arg.as_strided(ph_arg.size(), real_arg.stride())
+
             with tracing(saved_context), context(), track_graph_compiling(aot_config, "backward"):
                 CompiledFunction.compiled_bw = aot_config.bw_compiler(
-                    bw_module, fx_placeholder_vals(bw_module)
+                    bw_module, placeholder_list
                 )

         ctx.maybe_clear_saved_tensors()

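A small sketch of the restriding idea in the hunk above (illustrative only, with made-up shapes): if the activation actually saved at runtime is channels last while the traced placeholder has default contiguous strides, as_strided gives the placeholder the real strides before the backward graph is compiled.

import torch

# real saved activation: channels last because of the layout optimization
saved = torch.randn(2, 8, 4, 4).to(memory_format=torch.channels_last)
# placeholder traced earlier with default contiguous strides
ph = torch.empty(2, 8, 4, 4)

if isinstance(ph, torch.Tensor) and ph.stride() != saved.stride():
    ph = ph.as_strided(ph.size(), saved.stride())

assert ph.stride() == saved.stride()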
torch/_inductor/codegen/triton.py (+77, -4)

@@ -25,11 +25,13 @@
     get_fused_kernel_name,
     get_kernel_category_by_source_code,
     get_kernel_metadata,
+    green_text,
     next_power_of_2,
     sympy_product,
     sympy_subs,
     sympy_symbol,
     unique,
+    yellow_text,
 )
 from ..virtualized import ops, V

@@ -1425,19 +1427,22 @@ def codegen_kernel_benchmark(self):
         with result.indent():
             name_cnt = itertools.count()
             var_names = []
-            for arg_name in call_args:
+            for arg_name, arg_sig in zip(call_args, signature):
                 var_name = f"arg_{next(name_cnt)}"
                 buf = V.graph.get_buffer(arg_name)
                 if buf:
                     result.writeline(
-                        f"{var_name} = rand_strided({tuple(buf.get_size())}, {tuple(buf.get_stride())}, device='{buf.get_device()}', dtype={buf.get_dtype()})"  # noqa: B950 line too long
+                        f"{var_name} = rand_strided({V.graph.sizevars.size_hints(buf.get_size())}, {V.graph.sizevars.size_hints(buf.get_stride())}, device='{buf.get_device()}', dtype={buf.get_dtype()})"  # noqa: B950 line too long
                     )
                 elif arg_name in V.graph.constants:
                     # note that random seed is put in V.graph.constants
                     const_tensor = V.graph.constants[arg_name]
                     result.writeline(
-                        f"{var_name} = rand_strided({tuple(const_tensor.size())}, {tuple(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})"  # noqa: B950 line too long
+                        f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})"  # noqa: B950 line too long
                     )
+                elif isinstance(arg_sig, SizeArg):
+                    symval_hint = V.graph.sizevars.size_hint(arg_sig.expr)
+                    result.writeline(f"{var_name} = {symval_hint}")
                 else:
                     raise KeyError(
                         f"Don't find the buffer or const tensor for {arg_name}"

@@ -1457,7 +1462,7 @@ def codegen_kernel_benchmark(self):
                     f"torch.cuda.set_device({index})"
                 )  # no-op to ensure context
             for tree in self.range_trees:
-                expr = pexpr(tree.numel)
+                expr = pexpr(V.graph.sizevars.size_hint(tree.numel))
                 if tree.prefix != "r" or self.inside_reduction:
                     extra_args.append(expr)
                 if tree.prefix != "r":

@@ -1730,6 +1735,71 @@ def call_kernel(self, name: str):
             V.graph.scheduler.current_device.index,
         )

+    def warn_mix_layout(self, kernel_name):
+        """
+        Print a message if the kernel has mixed-layout inputs.
+        Only care about 4D tensors for now.
+        """
+        if (
+            len(self.args.input_buffers) == 1
+            and len(self.args.output_buffers) == 1
+            and len(self.args.inplace_buffers) == 0
+        ):
+            # Even if the input buffer and output buffer have different layouts,
+            # this can be a layout conversion kernel. No need to warn about
+            # mixed layouts here.
+            return
+
+        argdefs, call_args, signature = self.args.python_argdefs()
+        uniform_stride_order = None
+        for arg_name in call_args:
+            buf = V.graph.get_buffer(arg_name)
+            if buf and len(buf.layout.size) == 4:
+                # ignore the tensor if only one dimension is larger than 1
+                if len([x for x in buf.layout.size if x == 1]) == 3:
+                    continue
+                stride_order = ir.get_stride_order(buf.layout.stride)
+                if uniform_stride_order is None:
+                    uniform_stride_order = stride_order
+                elif uniform_stride_order != stride_order:
+                    msg = yellow_text(
+                        f"Expected stride order {uniform_stride_order}, but found stride order"
+                        + f" {stride_order} for kernel {kernel_name}"
+                    )
+                    log.warning(msg)
+
+                    stride_order_list = [
+                        ir.get_stride_order(V.graph.get_buffer(name).layout.stride)
+                        if V.graph.get_buffer(name)
+                        else None
+                        for name in call_args
+                    ]
+                    size_list = [
+                        V.graph.get_buffer(name).layout.size
+                        if V.graph.get_buffer(name)
+                        else None
+                        for name in call_args
+                    ]
+                    source_list = [
+                        "GraphInput"
+                        if name in V.graph.graph_inputs
+                        else "IntermediateBuffer"
+                        if name in V.graph.name_to_buffer
+                        else None
+                        for name in call_args
+                    ]
+
+                    msg = yellow_text(
+                        f" param names {argdefs}\n buf names {call_args}\n strides {stride_order_list}"
+                        + f"\n sizes {size_list}\n sources {source_list}\n"
+                    )
+                    log.warning(msg)
+                    return
+        msg = green_text(
+            f"All the inputs for the triton kernel {kernel_name} have uniform layout"
+        )
+        log.warning(msg)
+
     def create_cse_var(self, *args, **kwargs):
         return TritonCSEVariable(*args, **kwargs)

@@ -2014,6 +2084,9 @@ def codegen_node_schedule(self, node_schedule, numel, reduction_numel):

             kernel.call_kernel(kernel_name)

+            if config.warn_mix_layout:
+                kernel.warn_mix_layout(kernel_name)
+
             if (
                 V.graph.wrapper_code.supports_intermediate_hooks
                 and config.generate_intermediate_hooks

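For intuition about the stride orders that warn_mix_layout compares, here is a conceptual sketch (not the inductor implementation of ir.get_stride_order): rank each dimension by its stride, so contiguous NCHW and channels-last tensors of the same shape get different orders.

import torch

def stride_order(strides):
    # rank each dimension by its stride; the dimension with the smallest
    # stride gets rank 0 (it is the fastest-moving one in memory)
    sorted_dims = sorted(range(len(strides)), key=lambda i: strides[i])
    order = [0] * len(strides)
    for rank, dim in enumerate(sorted_dims):
        order[dim] = rank
    return order

a = torch.empty(2, 3, 8, 8)                                      # contiguous NCHW
b = torch.empty(2, 3, 8, 8).to(memory_format=torch.channels_last)

print(stride_order(a.stride()))  # [3, 2, 1, 0]
print(stride_order(b.stride()))  # [3, 0, 2, 1]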
torch/_inductor/codegen/wrapper.py (+11)

@@ -342,6 +342,14 @@ def get_output_refs(self):
     def mark_output_type(self):
         return

+    def codegen_input_size_asserts(self):
+        for name, buf in V.graph.graph_inputs.items():
+            if isinstance(buf, sympy.Expr):
+                continue
+            size = self.codegen_shape_tuple(buf.get_size())
+            stride = self.codegen_shape_tuple(buf.get_stride())
+            self.prefix.writeline(f"assert_size_stride({name}, {size}, {stride})")
+
     def write_prefix(self):
         self.prefix.splice(
             """

@@ -360,7 +368,10 @@ def call(args):
                 lhs = f"{', '.join(V.graph.graph_inputs.keys())}{'' if inp_len != 1 else ','}"
                 self.prefix.writeline(f"{lhs} = args")
                 self.prefix.writeline("args.clear()")
+
             self.codegen_inputs(self.prefix, V.graph.graph_inputs)
+            if config.size_asserts:
+                self.codegen_input_size_asserts()

     def write_get_cuda_stream(self, index):
         self.write_triton_header_once()

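The assert_size_stride calls emitted above guard the generated wrapper code against inputs arriving with an unexpected layout. A hypothetical pure-Python stand-in for that check, only to illustrate what the generated assertion verifies (the actual helper ships with PyTorch and is available to the generated code):

import torch

def assert_size_stride_py(t, size, stride):
    # hypothetical stand-in for the assert_size_stride helper used by generated code
    assert tuple(t.size()) == tuple(size), f"expected size {size}, got {tuple(t.size())}"
    assert tuple(t.stride()) == tuple(stride), f"expected stride {stride}, got {tuple(t.stride())}"

x = torch.randn(2, 3, 16, 16).to(memory_format=torch.channels_last)
assert_size_stride_py(x, (2, 3, 16, 16), (768, 1, 48, 3))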