test_compute_comm_reordering.py
# Owner(s): ["module: inductor"]
import unittest
from unittest.mock import patch
import torch
import torch._dynamo
import torch._dynamo.logging
import torch._dynamo.test_case
# for some reason importing functional collectives after dynamo breaks collectives handling!
import torch.distributed._functional_collectives as _functional_collectives
from torch._C import FileCheck
from torch._dynamo.utils import same
from torch._inductor import ir, scheduler
from torch._inductor.comm_analysis import (
baseLat,
hwLat,
llMaxBws,
NCCL_ALGO,
NCCL_HW,
NCCL_PROTO,
NVIDIA_GPU_TYPE,
)
from torch._inductor.utils import run_and_get_triton_code
from torch.testing._internal.common_distributed import (
_dynamo_dist_per_rank_init,
at_least_x_gpu,
DynamoDistributedMultiProcTestCase,
requires_nccl,
)
from torch.testing._internal.common_utils import skipIfRocm
from torch.testing._internal.inductor_utils import HAS_GPU


def get_snode_runtime_for_reorder_compute_test(snode):
# NOTE: custom cost model to show that the compute reordering algorithm is working
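    # This cost model is installed through the torch._inductor.config.estimate_op_runtime
    # hook (see test_reorder_compute_for_overlap_custom_runtime_estimation below):
    #
    #     @patch.object(
    #         torch._inductor.config,
    #         "estimate_op_runtime",
    #         get_snode_runtime_for_reorder_compute_test,
    #     )
    #
    # The reordering passes call the hook with each scheduler node and treat the return
    # value as that node's estimated runtime (unitless here), which drives how much
    # compute gets overlapped with each collective.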
# Collective kernels
if isinstance(snode.node, ir._CollectiveKernel):
return 100
elif isinstance(snode.node, ir._WaitKernel):
return 0
# High-arithmetic-intensity compute kernels
elif isinstance(snode.node, ir.ExternKernel):
return 5
# All other kernels
    return 1


def create_grouped_node_for_allreduce_and_its_deps(snodes):
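    # Wrap the in-place all_reduce_ snode and its single input snode in a
    # GroupedSchedulerNode, move that group to the front of the schedule, and keep
    # all remaining snodes in their original relative order.
    #
    # This helper is installed as torch._inductor.config._pre_fusion_custom_pass in
    # test_grouped_scheduler_node below; the scheduler is expected to call it with
    # its snode list before fusion and to schedule the list it returns.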
name_to_snode = {snode.node.name: snode for snode in snodes}
all_reduce_snodes = [
snode
for snode in snodes
if isinstance(snode.node, ir._CollectiveKernel)
and snode.node.op_overload == torch.ops._c10d_functional.all_reduce_.default
]
assert len(all_reduce_snodes) == 1
all_reduce_snode = all_reduce_snodes[0]
all_reduce_dep_snodes = [
name_to_snode[node.name] for node in all_reduce_snode.node.inputs
]
assert len(all_reduce_dep_snodes) == 1
all_reduce_dep_snode = all_reduce_dep_snodes[0]
grouped_snode = scheduler.GroupedSchedulerNode.create(
[all_reduce_dep_snode, all_reduce_snode]
)
new_snode_order = []
new_snode_order.append(grouped_snode)
for snode in snodes:
if snode in grouped_snode.snodes:
continue
new_snode_order.append(snode)
    return new_snode_order


@requires_nccl()
class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
"""
    Run correctness checks in the multi-proc runner; mark each test with the minimum # of GPUs it needs.
"""
def get_world_trs(self):
return {
"tag": "",
"ranks": list(range(self.world_size)),
"group_size": self.world_size,
}
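
    # NOTE: (tag, ranks, group_size) mirrors the process-group arguments that the
    # traceable functional collectives accept (e.g. all_reduce(x, "sum", ranks, tag));
    # the test funcs below take them as keyword-only args so the same
    # **self.get_world_trs() kwargs can be passed to both the compiled and eager runs.
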
@property
def world_size(self) -> int:
        # hack: no matter whether we have 2, 3, or 4 GPUs, just run on 2.
        # Works around an issue with skipif<2 and workers that have an
        # unpredictable number of GPUs.
        return 2

@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
@patch.object(torch._inductor.config, "reorder_for_locality", False)
@patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
@patch.object(
torch._inductor.config,
"reorder_for_compute_comm_overlap_passes",
[
"sink_waits",
],
)
def test_sink_waits(self):
def func(a):
ar = _functional_collectives.all_reduce(a, "sum", "0")
b = torch.matmul(a, a)
return torch.matmul(ar, b)
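
        # `b` does not depend on `ar`, so the first matmul can overlap with the
        # in-flight all_reduce; only the second matmul consumes `ar`, so the wait
        # only has to land before it.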
with _dynamo_dist_per_rank_init(
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
):
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs)
            # Verify that the wait_tensor is sunk below the 1st matmul but
# above the 2nd matmul.
(
FileCheck()
.check("torch.ops._c10d_functional.all_reduce_.default")
.check("extern_kernels.mm")
.check("torch.ops._c10d_functional.wait_tensor.default")
.check("extern_kernels.mm")
.run(code)
)
out = compiled(inputs)
correct = func(inputs)
            self.assertTrue(same(out, correct))

@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
@patch.object(torch._inductor.config, "reorder_for_locality", False)
@patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
@patch.object(
torch._inductor.config,
"reorder_for_compute_comm_overlap_passes",
[
"raise_comms",
],
)
def test_raise_comms(self):
def func(a):
b = torch.matmul(a, a)
c = torch.relu(b)
d = torch.matmul(c, c)
e = _functional_collectives.all_reduce(b, "sum", "0")
return torch.matmul(d, e)
with _dynamo_dist_per_rank_init(
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
):
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs)
            # Verify that the all_reduce_ has been raised above the 2nd matmul
            # but kept below the 1st matmul. Note that the all_reduce_ writes
            # directly to the output buffer of the 1st matmul, which is also an
            # input to the relu. Therefore, the all_reduce_ must be scheduled
            # after the relu.
(
FileCheck()
.check("extern_kernels.mm")
.check("triton_poi_fused_relu")
.check("torch.ops._c10d_functional.all_reduce_.default")
.check("extern_kernels.mm")
.check("torch.ops._c10d_functional.wait_tensor.default")
.check("extern_kernels.mm")
.run(code)
)
out = compiled(inputs)
correct = func(inputs)
            self.assertTrue(same(out, correct))

@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
@patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
@patch.object(
torch._inductor.config,
"reorder_for_compute_comm_overlap_passes",
[
"sink_waits",
"raise_comms",
],
)
def test_sink_waits_raise_comms(self):
def func(a, *, tag, ranks, group_size):
b = torch.matmul(a, a)
c = torch.relu(b)
d = torch.matmul(c, c)
e = _functional_collectives.all_reduce(b, "sum", "0")
f = torch.relu(d)
g = torch.matmul(f, f)
return torch.mm(e, g)
with _dynamo_dist_per_rank_init(
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
):
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # Things to verify:
            # - The clone prologue of the all_reduce_ (the functional all_reduce is
            #   lowered to a clone of `b` followed by an in-place all_reduce_ on
            #   that clone) should not be fused with any relus.
            # - The all_reduce_ and its prologue should be raised above the 2nd
            #   matmul but kept below the 1st matmul.
            # - The wait_tensor should be sunk below the 3rd matmul but above
            #   the 4th matmul.
(
FileCheck()
.check("extern_kernels.mm")
.check("triton_poi_fused_all_reduce_0")
.check("torch.ops._c10d_functional.all_reduce_.default")
.check("triton_poi_fused_relu")
.check("extern_kernels.mm")
.check("triton_poi_fused_relu")
.check("extern_kernels.mm")
.check("torch.ops._c10d_functional.wait_tensor.default")
.check("extern_kernels.mm")
.run(code)
)
out = compiled(inputs, **self.get_world_trs())
correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
@patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
@patch.object(
torch._inductor.config,
"reorder_for_compute_comm_overlap_passes",
[
"reorder_compute_for_overlap",
],
)
def test_reorder_compute_for_overlap(self):
def func(a, *, tag, ranks, group_size):
ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
g = torch.matmul(a, a)
c = torch.relu(a)
d = torch.matmul(c, c)
f = d * c * ar
fr = _functional_collectives.all_reduce(f, "sum", ranks, tag)
e = torch.matmul(d + ar + fr, g)
return (e,)
with _dynamo_dist_per_rank_init(
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
):
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # NOTE: after scheduling the first all_reduce:
            # 1. we first schedule the ops (c and d) that ARE required for the second all_reduce but DO NOT depend on the first all_reduce.
            # 2. then, we schedule the ops (g) that ARE NOT required for the second all_reduce and DO NOT depend on the first all_reduce.
            # 3. then, we schedule the ops (f) that ARE required for the second all_reduce and DO depend on the first all_reduce.
            # 4. finally, we schedule the second all_reduce, followed by all ops that depend on it.
(
FileCheck()
.check("torch.ops._c10d_functional.all_reduce_.default")
.check("triton_poi_fused_relu")
.check("extern_kernels.mm")
.check("extern_kernels.mm")
.check("torch.ops._c10d_functional.wait_tensor.default")
.check("triton_poi_fused_all_reduce_mul")
.check("torch.ops._c10d_functional.all_reduce_.default")
.check("torch.ops._c10d_functional.wait_tensor.default")
.check("triton_poi_fused_add")
.check("extern_kernels.mm")
.run(code)
)
out = compiled(inputs, **self.get_world_trs())
correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
@patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
@patch.object(
torch._inductor.config,
"reorder_for_compute_comm_overlap_passes",
[
"reorder_compute_for_overlap",
],
)
@patch.object(
torch._inductor.config,
"estimate_op_runtime",
get_snode_runtime_for_reorder_compute_test,
)
def test_reorder_compute_for_overlap_custom_runtime_estimation(self):
def func(a, *, tag, ranks, group_size):
ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
g = torch.matmul(a, a)
c = torch.relu(a)
d = torch.matmul(c, c)
f = d * c * ar
fr = _functional_collectives.all_reduce(f, "sum", ranks, tag)
e = torch.matmul(d + ar + fr, g)
return (e,)
with _dynamo_dist_per_rank_init(
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
):
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # NOTE: after scheduling the first all_reduce:
            # 1. we first schedule the ops (c and d) that ARE required for the second all_reduce but DO NOT depend on the first all_reduce.
            # 2. then, we schedule the ops (g) that ARE NOT required for the second all_reduce and DO NOT depend on the first all_reduce.
            # 3. then, we schedule the ops (f) that ARE required for the second all_reduce and DO depend on the first all_reduce.
            # 4. finally, we schedule the second all_reduce, followed by all ops that depend on it.
(
FileCheck()
.check("torch.ops._c10d_functional.all_reduce_.default")
.check("triton_poi_fused_relu")
.check("extern_kernels.mm")
.check("extern_kernels.mm")
.check("torch.ops._c10d_functional.wait_tensor.default")
.check("triton_poi_fused_all_reduce_mul")
.check("torch.ops._c10d_functional.all_reduce_.default")
.check("torch.ops._c10d_functional.wait_tensor.default")
.check("triton_poi_fused_add")
.check("extern_kernels.mm")
.run(code)
)
out = compiled(inputs, **self.get_world_trs())
correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skipIfRocm
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
@patch.object(
torch._inductor.config,
"_pre_fusion_custom_pass",
create_grouped_node_for_allreduce_and_its_deps,
)
def test_grouped_scheduler_node(self):
def func(a, *, tag, ranks, group_size):
add = a + a
div = add / a
ar = _functional_collectives.all_reduce(div, "sum", ranks, tag)
            # Normally, `add = a + a`, `div = add / a`, and `mul = a * a` would be fused
            # into a single op, but in this unit test we intentionally put the `add`, `div`,
            # and `ar` computation into a GroupedSchedulerNode, which prevents them from
            # being fused with any other ops.
mul = a * a
mm = torch.matmul(mul, ar)
return (mm,)
with _dynamo_dist_per_rank_init(
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
):
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
# Expectations:
            # 1. `add = a + a` and `div = add / a` are still fused, which means fusion
            #    still happens among nodes within a GroupedSchedulerNode.
            # 2. `mul = a * a` is not fused with `add` or `div`, because the latter two are
            #    within the GroupedSchedulerNode and thus are prevented from being fused
            #    with any outside ops.
FileCheck().check("triton_poi_fused_add_all_reduce_div_0.").check(
"_c10d_functional.all_reduce_."
).check("triton_poi_fused_mul_1.").run(code)
out = compiled(inputs, **self.get_world_trs())
correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

def test_nccl_heuristics(self):
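        # The latency/bandwidth tables in comm_analysis are indexed as
        # baseLat[algo][proto], hwLat[hw][algo][proto], and llMaxBws[gpu_type];
        # these shape checks keep the tables in sync with their enums.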
assert len(baseLat) == len(NCCL_ALGO)
assert all(len(x) == len(NCCL_PROTO) for x in baseLat)
assert len(hwLat) == len(NCCL_HW)
assert all(len(x) == len(NCCL_ALGO) for x in hwLat)
assert all(len(y) == len(NCCL_PROTO) for x in hwLat for y in x)
        assert len(llMaxBws) == len(NVIDIA_GPU_TYPE)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

run_tests()