
Commit 7691064

avikchaudhuri authored and pytorchmergebot committed
dispatcher module for multiple graphs (pytorch#139439)
Differential Revision: [D65307961](https://our.internmc.facebook.com/intern/diff/D65307961/)

This PR introduces the concept of a "dispatcher" module `n` that carries multiple interpreter modules `n`, `n@1`, `n@2`, etc., each corresponding to a particular call of `n` and thus possibly carrying a different specialized graph. We only do this when we are preserving module call signatures for `n`. The carried modules correspond, in number and order, to the calls of `n` appearing in the original module / exported program.

In the unflattened module, all of those calls go to the dispatcher module, which internally tracks how many calls have been made so far and invokes the corresponding interpreter module. We reset this tracking after a successful or unsuccessful run of the unflattened module.

Overall this makes swapping easier when module call signatures are preserved.

Pull Request resolved: pytorch#139439
Approved by: https://github.com/tugsbayasgalan

ghstack dependencies: pytorch#139438
1 parent 9a5175e commit 7691064
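To make the described behavior concrete, here is a minimal sketch (not part of this commit) of the round trip the tests below exercise. The `M`/`N` fixtures mirror the test code, and the `preserve_module_call_signature` keyword on `torch.export.export` is assumed from those tests.

```python
# Illustrative sketch only: export a module that calls `n` twice, unflatten it,
# and rely on the dispatcher installed at "n" to replay the specialized graphs
# in call order. Names and keywords follow the tests in this PR.
import torch


class N(torch.nn.Module):
    def forward(self, x, b):
        # Different branches force differently specialized graphs per call site.
        return x + 1 if b else x + 2


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.n = N()

    def forward(self, x):
        x = self.n(x, True)   # first call of `n`
        x = self.n(x, False)  # second call of `n`, different specialization
        return x


inp = (torch.ones(1),)
ep = torch.export.export(M(), inp, preserve_module_call_signature=("n",))
ufm = torch.export.unflatten(ep)

# All calls of `n` route through the dispatcher, which invokes the interpreter
# module matching each call, so the unflattened module matches eager execution.
assert torch.allclose(ufm(*inp), M()(*inp))

# Swapping is a single submodule replacement at the preserved fqn.
ufm.set_submodule("n", N())
assert torch.allclose(ufm(*inp), M()(*inp))
```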

2 files changed: +175 -9

test/export/test_export.py

+75 -7

@@ -6767,7 +6767,7 @@ def test(ep, swap):
         if not is_retracebility_test(self._testMethodName):
             test(
                 export(M(), inp, preserve_module_call_signature=("n",)),
-                swap={"n": N(), "n@1": N()},
+                swap={"n": N()},
             )

         class _N(torch.nn.Module):
@@ -6820,7 +6820,7 @@ def forward(self, x):
         unflattened_result = ufm(*inp)
         self.assertTrue(torch.allclose(unflattened_result, eager_result))

-    def test_unflatten_multiple_graphs_preserve_signature_no_error(self):
+    def test_unflatten_multiple_graphs_dispatch(self):
         class N(torch.nn.Module):
             def forward(self, x, b):
                 if b:
@@ -6837,8 +6837,10 @@ def forward(self, x):
                 x = x + 3
                 x = self.n(x, True)
                 x = x + 4
-                x = self.n(x, False)
+                x = self.n(x, True)
                 x = x + 5
+                x = self.n(x, False)
+                x = x + 6
                 return x

         inp = (torch.ones(1),)
@@ -6856,8 +6858,65 @@ def test(ep):
             self.assertTrue(torch.allclose(unflattened_result, eager_result))

         if not is_retracebility_test(self._testMethodName):
+            if is_training_ir_test(self._testMethodName):
+                test(
+                    torch.export.export_for_training(
+                        M(),
+                        inp,
+                        strict=not is_non_strict_test(self._testMethodName),
+                        preserve_module_call_signature=("n",),
+                    )
+                )
+
             test(export(M(), inp, preserve_module_call_signature=("n",)))

+    def test_unflatten_multiple_graphs_preserve_signature_no_error(self):
+        class N(torch.nn.Module):
+            def forward(self, x, b):
+                if b:
+                    return x + 1
+                else:
+                    return x + 2
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.n = N()
+
+            def forward(self, x):
+                x = x + 3
+                x = self.n(x, True)
+                x = x + 4
+                x = self.n(x, False)
+                x = x + 5
+                return x
+
+        inp = (torch.ones(1),)
+        m = M()
+        eager_result = m(*inp)
+
+        def test(ep, swap=None):
+            epm = ep.module()
+            ufm = torch.export.unflatten(ep)
+
+            exported_result = epm(*inp)
+            self.assertTrue(torch.allclose(exported_result, eager_result))
+
+            unflattened_result = ufm(*inp)
+            self.assertTrue(torch.allclose(unflattened_result, eager_result))
+
+            if swap:
+                for fqn, mod in swap.items():
+                    ufm.set_submodule(fqn, mod)
+                unflattened_result = ufm(*inp)
+                self.assertTrue(torch.allclose(unflattened_result, eager_result))
+
+        if not is_retracebility_test(self._testMethodName):
+            test(
+                export(M(), inp, preserve_module_call_signature=("n",)),
+                swap={"n": N()},
+            )
+
         test(export(M(), inp))

     @testing.expectedFailureRetraceabilityNonStrict
@@ -6893,7 +6952,7 @@ def forward(self, x):
         m = M()
         eager_result = m(*inp)

-        def test(ep):
+        def test(ep, swap=None):
             epm = ep.module()
             ufm = torch.export.unflatten(ep)

@@ -6903,11 +6962,20 @@ def test(ep):
             unflattened_result = ufm(*inp)
             self.assertTrue(torch.allclose(unflattened_result, eager_result))

+            if swap:
+                for fqn, mod in swap.items():
+                    ufm.set_submodule(fqn, mod)
+                unflattened_result = ufm(*inp)
+                self.assertTrue(torch.allclose(unflattened_result, eager_result))
+
         if not is_retracebility_test(self._testMethodName):
-            test(export(M(), inp, preserve_module_call_signature=("n",)))
+            test(
+                export(M(), inp, preserve_module_call_signature=("n",)),
+                swap={"n": N()},
+            )
             # running decompositions again should work for all IRs
             ep = export(M(), inp, preserve_module_call_signature=("n",))
-            test(ep.run_decompositions({}))
+            test(ep.run_decompositions({}), swap={"n": N()})
             if is_training_ir_test(self._testMethodName):
                 # since we run decompositions by default when testing training IR,
                 # also test training IR without running decompositions
@@ -6918,7 +6986,7 @@ def test(ep):
                     strict=strict,
                     preserve_module_call_signature=("n",),
                 )
-            test(ept)
+            test(ept, swap={"n": N()})

         test(export(M(), inp))
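The swaps above replace the preserved submodule by its original fqn. As a further illustration (again not from the commit), the dispatcher that `torch.export.unflatten` installs at that fqn can be inspected with the `call_modules()` and `print_readable()` accessors added in torch/export/unflatten.py below.

```python
# Hypothetical inspection snippet: build a two-call fixture like the tests above
# and look at the dispatcher that unflatten() installs at the preserved fqn "n".
import torch


class N(torch.nn.Module):
    def forward(self, x, b):
        return x + 1 if b else x + 2


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.n = N()

    def forward(self, x):
        return self.n(self.n(x, True), False)  # two differently specialized calls


ep = torch.export.export(M(), (torch.ones(1),), preserve_module_call_signature=("n",))
ufm = torch.export.unflatten(ep)

disp = ufm.get_submodule("n")                   # the InterpreterModuleDispatcher
print(len(disp.call_modules()))                 # one interpreter module per call of `n`
print(disp.print_readable(print_output=False))  # readable graph for each carried call
```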

torch/export/unflatten.py

+100 -2

@@ -36,7 +36,13 @@
 log = logging.getLogger(__name__)


-__all__ = ["InterpreterModule", "UnflattenedModule", "unflatten", "FlatArgsAdapter"]
+__all__ = [
+    "FlatArgsAdapter",
+    "InterpreterModule",
+    "InterpreterModuleDispatcher",
+    "UnflattenedModule",
+    "unflatten",
+]


 class _AttrKind(Enum):
@@ -195,6 +201,50 @@ def print_readable(
         )


+class InterpreterModuleDispatcher(torch.nn.Module):
+    """
+    A module that carries a sequence of InterpreterModules corresponding to
+    a sequence of calls of that module. Each call to the module dispatches
+    to the next InterpreterModule, and wraps back around after the last.
+    """
+
+    def __init__(self, call_modules: List[InterpreterModule]):
+        super().__init__()
+        assert call_modules
+        self._call_modules = call_modules
+        self._num_calls = 0
+
+    def forward(self, *args, **kwargs):
+        call_module = self._call_modules[self._num_calls]
+        self._num_calls = (self._num_calls + 1) % len(self._call_modules)
+        try:
+            return call_module(*args, **kwargs)
+        except Exception:
+            self._num_calls = 0
+            raise
+
+    def call_modules(self):
+        return self._call_modules
+
+    def print_readable(
+        self,
+        print_output=True,
+        include_stride=False,
+        include_device=False,
+        colored=False,
+    ):
+        outputs = [
+            mod.print_readable(
+                print_output,
+                include_stride,
+                include_device,
+                colored,
+            )
+            for mod in self._call_modules
+        ]
+        return "\n".join(outputs)
+
+
 class FlatArgsAdapter(abc.ABC):
     """
     Adapts input arguments with ``input_spec`` to align ``target_spec``.
@@ -415,7 +465,7 @@ def add_to_consts_map(obj_id, node_name, target_name):
             inputs_to_state[n] = targets

         _sink_params(self, inputs_to_state, [])
-        _deduplicate_modules(seen_modules.values())
+        redirected_call_indices = _deduplicate_modules(seen_modules.values())

         # Helper function to check input nodes of `module` has been processed.
         def check_module_inputs(module, scope):
@@ -445,6 +495,7 @@ def check_module_inputs(module, scope):

         # Recurively check all input nodes have been processed.
         check_module_inputs(self, [])
+        self._dispatch_modules(redirected_call_indices)

         # Cache so we don't have to compute this every time.
         # NOTE: this needs to be kept in sync with the placeholders in
@@ -541,6 +592,49 @@ def forward(self, *args, **kwargs):
             )
         return pytree.tree_unflatten(tree_out, signature.out_spec)

+    def _dispatch_modules(self, redirected_call_indices):
+        """For a module whose call signatures are preserved, replace
+        multiple modules corresponding to multiple calls to that module
+        with a single dispatcher module that tracks which module to call.
+        """
+
+        # some modules were removed and their fqns redirected to other
+        # fqns during deduplication; make a consolidated fqn -> module map
+        all_modules = {}
+        for fqn, mod in self.named_modules(remove_duplicate=False):
+            all_modules[fqn] = mod
+        for fqn, fqn_ in redirected_call_indices.items():
+            all_modules[fqn] = all_modules[fqn_]
+
+        # for each fqn whose module call signature is preserved,
+        # map that fqn to a list of called modules
+        module_call_graph = {
+            entry.fqn
+            for entry in self.module_call_graph
+            if entry.fqn and entry.signature
+        }
+        called_modules = defaultdict(list)
+        for fqn, mod in sorted(all_modules.items()):
+            if fqn in module_call_graph:
+                called_modules[fqn.split("@")[0]].append(mod)
+
+        # replace multiple call modules with a single dispatcher module
+        for orig_fqn, call_modules in called_modules.items():
+            if len(call_modules) > 1:
+                for i, call_module in enumerate(call_modules):
+                    fqn = _call_name(orig_fqn, i + 1)
+                    if fqn not in redirected_call_indices:
+                        self._modules.pop(fqn)
+                self.set_submodule(orig_fqn, InterpreterModuleDispatcher(call_modules))
+
+        # elide call indices in call modules because they are
+        # tracked automatically inside the dispatcher module
+        for node in self.graph.nodes:
+            if node.op == "call_module":
+                fqn = node.target.split("@")[0]
+                if fqn in called_modules:
+                    node.target = fqn
+
     def print_readable(
         self,
         print_output=True,
@@ -1340,6 +1434,7 @@ def _copy_graph_attrs(


 def _deduplicate_modules(partitions):
+    redirected_call_indices = {}
     for shared_submodules in partitions:
         for i, entry in enumerate(shared_submodules):
             child_fqn = _call_name(entry.fqn, entry.call_idx)
@@ -1364,6 +1459,7 @@ def _deduplicate_modules(partitions):
                             entry.parent_fqn, seen_child_fqn
                         )
                         entry.parent_call_module.target = seen_target  # type: ignore[union-attr]
+                        redirected_call_indices[child_fqn] = seen_child_fqn
                         break
                     elif not deduplicated:
                         # Case 2: The current module has a different fqn than the seen module.
@@ -1378,6 +1474,8 @@ def _deduplicate_modules(partitions):
                         entry.parent_module.set_submodule(target, seen.module)
                         deduplicated = True

+    return redirected_call_indices
+

 def _sink_params(
     module: torch.nn.Module,
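The dispatch logic added above is small enough to illustrate on its own. The stand-in class below is hypothetical (the real `InterpreterModuleDispatcher` wraps `InterpreterModule`s produced during unflattening), but it mirrors the same next-call counter, wrap-around, and reset-on-exception behavior:

```python
# Toy stand-in mirroring InterpreterModuleDispatcher's dispatch logic, for
# illustration only; it accepts arbitrary nn.Modules instead of InterpreterModules.
from typing import List

import torch


class RoundRobinDispatcher(torch.nn.Module):
    def __init__(self, call_modules: List[torch.nn.Module]):
        super().__init__()
        assert call_modules
        self._call_modules = call_modules
        self._num_calls = 0  # index of the module serving the next call

    def forward(self, *args, **kwargs):
        call_module = self._call_modules[self._num_calls]
        # Advance and wrap before calling, so a full pass over all call sites
        # leaves the counter at 0 for the next run of the outer module.
        self._num_calls = (self._num_calls + 1) % len(self._call_modules)
        try:
            return call_module(*args, **kwargs)
        except Exception:
            # An unsuccessful run also resets the tracking, as described above.
            self._num_calls = 0
            raise


# Two stand-ins for the per-call specialized graphs of a twice-called submodule.
first, second = torch.nn.Linear(2, 2), torch.nn.Linear(2, 2)
dispatcher = RoundRobinDispatcher([first, second])

x = torch.randn(1, 2)
assert torch.allclose(dispatcher(x), first(x))   # call 1 -> first carried module
assert torch.allclose(dispatcher(x), second(x))  # call 2 -> second carried module
assert torch.allclose(dispatcher(x), first(x))   # wraps around for the next run
```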
