
Commit 040b925

jansel authored and pytorchmergebot committed
[Compiled Autograd] Reorder accumulate grad nodes (pytorch#121735)
Pull Request resolved: pytorch#121735
Approved by: https://github.com/xmfan
1 parent: f0b9a83

2 files changed, +23 -1 lines

test/inductor/test_compiled_autograd.py (+2)
@@ -1218,6 +1218,8 @@ def wrap_test_class(orig_cls):
     "test_backward_tensorlist_input_requires_list_grads_none_or_Tensor",  # AssertionError: "None or Tensor"
     "test_backward_tensorlist_input_requires_list_grads_with_same_numel",  # AssertionError: "3 gradients
     "test_save_for_backward_inputs_are_namedtuple",  # torch._dynamo.exc.Unsupported: 'skip function
+    "test_autograd_function_backed_op",  # RuntimeError: compiled_args not implemented
+    "test_setitem",  # AssertionError: Tensor-likes are not close!
 }
 
 if not HAS_CUDA:
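
These entries extend the known-failing list for the stock autograd tests when rerun under compiled autograd. For context, here is a minimal sketch of how those tests drive a backward pass through compiled autograd, so that end_capture (and with this commit, the reorder pass) actually fires. compiled_autograd.enable and torch.compile are real APIs; the model, input, and the "eager" backend choice are illustrative (the test suite typically uses inductor).

import torch
from torch._dynamo import compiled_autograd


def compiler_fn(gm):
    # Compile the FX graph captured from the backward pass; the "eager"
    # backend keeps this sketch dependency-light.
    return torch.compile(gm, backend="eager")


model = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)
with compiled_autograd.enable(compiler_fn):
    # backward() is captured into an FX graph and handed to compiler_fn.
    model(x).sum().backward()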

torch/_dynamo/compiled_autograd.py (+21, -1)
@@ -1,6 +1,7 @@
 import contextlib
 import functools
-from typing import List, Optional
+import itertools
+from typing import Dict, List, Optional
 
 import torch
 from torch._dynamo.external_utils import call_backward, call_hook
@@ -195,6 +196,7 @@ def end_capture(self, outputs):
             (self.fx_tracer.create_arg(self.to_proxy(outputs)),),
             {},
         )
+        self.reorder_accumulate_grad_nodes()
         graph = GraphModule(
             self.fx_tracer.root, self.fx_tracer.graph, "CompiledAutograd"
         )
@@ -207,6 +209,24 @@ def end_capture(self, outputs):
         )
         return self.compiler_fn(graph)
 
+    def reorder_accumulate_grad_nodes(self):
+        """
+        Usage of AOTAutograd causes all the accumulate_grad_ nodes to get pushed to the end of
+        the graph. This differs from eager mode, which schedules them as soon as possible. This
+        pass attempts to reorder the graph to mimic eager behavior.
+        """
+        order: Dict[torch.fx.Node, int] = {}
+        counter = itertools.count()
+        target = torch.ops.inductor.accumulate_grad_.default
+        last = None
+        for node in [*self.fx_tracer.graph.nodes]:
+            if node.op == "call_function" and node.target == target:
+                arg = max(node.args, key=order.get)  # type: ignore[arg-type]
+                if arg is not last:
+                    arg.append(node)
+            order[node] = next(counter)
+            last = node
+
     def to_proxy(self, t):
         if t is None:
             return None
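
The new pass leans on two torch.fx behaviors: graph.nodes iterates in graph order (so order ranks each node by position, and max(node.args, key=order.get) picks the latest-produced input), and Node.append(n) moves n to sit immediately after that node. Below is a minimal standalone sketch of the same reordering idea on a toy graph; sink_op is a hypothetical stand-in for torch.ops.inductor.accumulate_grad_.default (which is only meaningful inside an actual capture), while reorder_sink_nodes mirrors the structure of the new method and everything else is standard torch.fx.

import itertools
from typing import Dict

import torch
import torch.fx


def sink_op(grad):
    # Hypothetical stand-in for the accumulate_grad_ op.
    return grad


def reorder_sink_nodes(graph: torch.fx.Graph) -> None:
    # Same structure as the pass in this commit: walk a snapshot of the node
    # list, rank every node by its original position, and move each matching
    # node so it sits right after its latest-produced input.
    order: Dict[torch.fx.Node, int] = {}
    counter = itertools.count()
    last = None
    for node in [*graph.nodes]:
        if node.op == "call_function" and node.target == sink_op:
            # All args are Nodes in this toy graph, so order.get is a valid key.
            arg = max(node.args, key=order.get)
            if arg is not last:
                arg.append(node)  # Node.append inserts `node` right after `arg`
        order[node] = next(counter)
        last = node


g = torch.fx.Graph()
x = g.placeholder("x")
a = g.call_function(torch.sin, (x,))
b = g.call_function(torch.cos, (a,))
# Both sink nodes start at the end of the graph, AOTAutograd-style.
acc_a = g.call_function(sink_op, (a,))
acc_b = g.call_function(sink_op, (b,))
g.output((acc_a, acc_b))

reorder_sink_nodes(g)
print([node.name for node in g.nodes])
# ['x', 'sin', 'sink_op', 'cos', 'sink_op_1', 'output']

After the pass, each sink node sits directly after the node that produced its input rather than at the end of the graph, which is the eager-mode scheduling the docstring describes.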
