import contextlib
import functools
-from typing import List, Optional
+import itertools
+from typing import Dict, List, Optional

import torch
from torch._dynamo.external_utils import call_backward, call_hook
@@ -195,6 +196,7 @@ def end_capture(self, outputs):
             (self.fx_tracer.create_arg(self.to_proxy(outputs)),),
             {},
         )
+        self.reorder_accumulate_grad_nodes()
         graph = GraphModule(
             self.fx_tracer.root, self.fx_tracer.graph, "CompiledAutograd"
         )
@@ -207,6 +209,24 @@ def end_capture(self, outputs):
         )
         return self.compiler_fn(graph)

+    def reorder_accumulate_grad_nodes(self):
+        """
+        Usage of AOTAutograd causes all the accumulate_grad_ nodes to get pushed to the end of
+        the graph. This differs from eager mode, which schedules them as soon as possible. This
+        pass attempts to reorder the graph to mimic eager behavior.
+        """
+        order: Dict[torch.fx.Node, int] = {}
+        counter = itertools.count()
+        target = torch.ops.inductor.accumulate_grad_.default
+        last = None
+        for node in [*self.fx_tracer.graph.nodes]:
+            if node.op == "call_function" and node.target == target:
+                arg = max(node.args, key=order.get)  # type: ignore[arg-type]
+                if arg is not last:
+                    arg.append(node)
+            order[node] = next(counter)
+            last = node
+
     def to_proxy(self, t):
         if t is None:
             return None
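To illustrate what the new pass does, below is a minimal, self-contained sketch (not part of this commit) that applies the same reordering idea to a toy torch.fx graph. Here demo_accumulate_grad is a hypothetical stand-in for torch.ops.inductor.accumulate_grad_.default, and torch.sin/torch.cos stand in for gradient-producing nodes; the reorder loop mirrors reorder_accumulate_grad_nodes() above.

import itertools
from typing import Dict

import torch
import torch.fx as fx


def demo_accumulate_grad(variable, new_grad):  # hypothetical stand-in op
    pass


# Build a graph where both accumulate nodes sit at the very end,
# roughly how AOTAutograd leaves them.
graph = fx.Graph()
param = graph.placeholder("param")
grad0 = graph.call_function(torch.sin, (param,))
grad1 = graph.call_function(torch.cos, (param,))
acc0 = graph.call_function(demo_accumulate_grad, (param, grad0))
acc1 = graph.call_function(demo_accumulate_grad, (param, grad1))
graph.output(None)
print("before:", [n.name for n in graph.nodes])

# Same idea as reorder_accumulate_grad_nodes(): hoist each accumulate node
# so it directly follows its latest-produced argument.
order: Dict[fx.Node, int] = {}
counter = itertools.count()
last = None
for node in [*graph.nodes]:
    if node.op == "call_function" and node.target is demo_accumulate_grad:
        arg = max(node.args, key=order.get)
        if arg is not last:
            arg.append(node)  # Node.append moves `node` to right after `arg`
    order[node] = next(counter)
    last = node
print("after:", [n.name for n in graph.nodes])

With these assumptions, the second print should show each accumulate node moved from the tail of the graph up next to the gradient it consumes, which is the eager-style scheduling the docstring describes.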