 log = logging.getLogger(__name__)
 
 
-__all__ = ["InterpreterModule", "UnflattenedModule", "unflatten", "FlatArgsAdapter"]
+__all__ = [
+    "FlatArgsAdapter",
+    "InterpreterModule",
+    "InterpreterModuleDispatcher",
+    "UnflattenedModule",
+    "unflatten",
+]
 
 
 class _AttrKind(Enum):
@@ -195,6 +201,50 @@ def print_readable(
         )
 
 
+class InterpreterModuleDispatcher(torch.nn.Module):
+    """
+    A module that carries a sequence of InterpreterModules corresponding to
+    a sequence of calls of that module. Each call to the module dispatches
+    to the next InterpreterModule, and wraps back around after the last.
+    """
+
+    def __init__(self, call_modules: List[InterpreterModule]):
+        super().__init__()
+        assert call_modules
+        self._call_modules = call_modules
+        self._num_calls = 0
+
+    def forward(self, *args, **kwargs):
+        call_module = self._call_modules[self._num_calls]
+        self._num_calls = (self._num_calls + 1) % len(self._call_modules)
+        try:
+            return call_module(*args, **kwargs)
+        except Exception:
+            self._num_calls = 0
+            raise
+
+    def call_modules(self):
+        return self._call_modules
+
+    def print_readable(
+        self,
+        print_output=True,
+        include_stride=False,
+        include_device=False,
+        colored=False,
+    ):
+        outputs = [
+            mod.print_readable(
+                print_output,
+                include_stride,
+                include_device,
+                colored,
+            )
+            for mod in self._call_modules
+        ]
+        return "\n".join(outputs)
+
+
 class FlatArgsAdapter(abc.ABC):
     """
     Adapts input arguments with ``input_spec`` to align ``target_spec``.
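
To make the round-robin contract above concrete, here is a minimal usage sketch. It assumes the `InterpreterModuleDispatcher` class from this diff is in scope, and uses plain `nn.Module`s as stand-ins for `InterpreterModule`s, since the dispatcher only requires its entries to be callable:

```python
import torch

class _Scale(torch.nn.Module):
    """Stand-in for an InterpreterModule; multiplies by a fixed factor."""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def forward(self, x):
        return x * self.factor

# One entry per original call of the module.
dispatcher = InterpreterModuleDispatcher([_Scale(2), _Scale(3)])

x = torch.ones(2)
print(dispatcher(x))  # 1st call -> 1st module: tensor([2., 2.])
print(dispatcher(x))  # 2nd call -> 2nd module: tensor([3., 3.])
print(dispatcher(x))  # wraps around to the 1st module: tensor([2., 2.])
```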
@@ -415,7 +465,7 @@ def add_to_consts_map(obj_id, node_name, target_name):
             inputs_to_state[n] = targets
 
         _sink_params(self, inputs_to_state, [])
-        _deduplicate_modules(seen_modules.values())
+        redirected_call_indices = _deduplicate_modules(seen_modules.values())
 
         # Helper function to check that input nodes of `module` have been processed.
         def check_module_inputs(module, scope):
@@ -445,6 +495,7 @@ def check_module_inputs(module, scope):
 
         # Recursively check all input nodes have been processed.
         check_module_inputs(self, [])
+        self._dispatch_modules(redirected_call_indices)
 
         # Cache so we don't have to compute this every time.
         # NOTE: this needs to be kept in sync with the placeholders in
@@ -541,6 +592,49 @@ def forward(self, *args, **kwargs):
             )
         return pytree.tree_unflatten(tree_out, signature.out_spec)
 
+    def _dispatch_modules(self, redirected_call_indices):
+        """For a module whose call signatures are preserved, replace
+        multiple modules corresponding to multiple calls to that module
+        with a single dispatcher module that tracks which module to call.
+        """
+
+        # some modules were removed and their fqns redirected to other
+        # fqns during deduplication; make a consolidated fqn -> module map
+        all_modules = {}
+        for fqn, mod in self.named_modules(remove_duplicate=False):
+            all_modules[fqn] = mod
+        for fqn, fqn_ in redirected_call_indices.items():
+            all_modules[fqn] = all_modules[fqn_]
+
+        # for each fqn whose module call signature is preserved,
+        # map that fqn to a list of called modules
+        module_call_graph = {
+            entry.fqn
+            for entry in self.module_call_graph
+            if entry.fqn and entry.signature
+        }
+        called_modules = defaultdict(list)
+        for fqn, mod in sorted(all_modules.items()):
+            if fqn in module_call_graph:
+                called_modules[fqn.split("@")[0]].append(mod)
+
+        # replace multiple call modules with a single dispatcher module
+        for orig_fqn, call_modules in called_modules.items():
+            if len(call_modules) > 1:
+                for i, call_module in enumerate(call_modules):
+                    fqn = _call_name(orig_fqn, i + 1)
+                    if fqn not in redirected_call_indices:
+                        self._modules.pop(fqn)
+                self.set_submodule(orig_fqn, InterpreterModuleDispatcher(call_modules))
+
+        # elide call indices in call modules because they are
+        # tracked automatically inside the dispatcher module
+        for node in self.graph.nodes:
+            if node.op == "call_module":
+                fqn = node.target.split("@")[0]
+                if fqn in called_modules:
+                    node.target = fqn
+
     def print_readable(
         self,
         print_output=True,
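
The grouping step in `_dispatch_modules` may be easier to follow in isolation. Below is a minimal sketch with plain dicts and strings as stand-ins; the `name@<n>` fqn suffix for repeated calls is inferred from the `fqn.split("@")[0]` above, and the fqns here are invented for illustration:

```python
from collections import defaultdict

# Hypothetical state after deduplication: "layer" was called twice, and the
# second call's module was redirected onto the first ("layer@1" -> "layer").
all_modules = {"layer": "mod_a", "layer@1": "mod_a"}  # fqn -> module stand-in
preserved = {"layer", "layer@1"}  # fqns whose call signatures are preserved

called_modules = defaultdict(list)
for fqn, mod in sorted(all_modules.items()):
    if fqn in preserved:
        called_modules[fqn.split("@")[0]].append(mod)

print(dict(called_modules))  # {'layer': ['mod_a', 'mod_a']}
# "layer" has more than one recorded call, so it would be replaced by a
# single InterpreterModuleDispatcher wrapping both call modules.
```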
@@ -1340,6 +1434,7 @@ def _copy_graph_attrs(
 
 
 def _deduplicate_modules(partitions):
+    redirected_call_indices = {}
     for shared_submodules in partitions:
         for i, entry in enumerate(shared_submodules):
             child_fqn = _call_name(entry.fqn, entry.call_idx)
@@ -1364,6 +1459,7 @@ def _deduplicate_modules(partitions):
                         entry.parent_fqn, seen_child_fqn
                     )
                     entry.parent_call_module.target = seen_target  # type: ignore[union-attr]
+                    redirected_call_indices[child_fqn] = seen_child_fqn
                     break
                 elif not deduplicated:
                     # Case 2: The current module has a different fqn than the seen module.
@@ -1378,6 +1474,8 @@ def _deduplicate_modules(partitions):
                     entry.parent_module.set_submodule(target, seen.module)
                     deduplicated = True
 
+    return redirected_call_indices
+
 
 def _sink_params(
     module: torch.nn.Module,
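
Stepping back, here is a sketch of how this change surfaces to users of `torch.export`. It assumes `export`'s `preserve_module_call_signature` keyword, and that with this diff applied, a preserved module called more than once unflattens to a dispatcher; the model is invented for illustration:

```python
import torch
from torch.export import export, unflatten

class Block(torch.nn.Module):
    def forward(self, x):
        return x + 1

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.block = Block()

    def forward(self, x):
        # Two calls to the same submodule.
        return self.block(self.block(x))

ep = export(Model(), (torch.randn(4),), preserve_module_call_signature=("block",))
unflat = unflatten(ep)
# Expected with this change: unflat.block is an InterpreterModuleDispatcher
# carrying one InterpreterModule per original call, cycling on each forward.
print(type(unflat.block).__name__)
```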