@@ -72,7 +72,7 @@ class DecompositionInterpreter(Interpreter):
DecompositionInterpreter takes the high-level graph module, runs the internal nodes following the topological order, and decomposes
high-level PyTorch operators into core ATen operators by utilizing the torch dispatch infrastructure along the way. Note
that certain primitive layers (like `nn.Linear`, `nn.Embedding`, and activation layers) are preserved because we have specific
- heuristic based parallelization strategy for them and we can conveniently replace them into their parallelized counterparts
+ heuristic-based parallelization strategies for them, so that we can conveniently replace them with their parallelized counterparts
in the original graph module.

Note that the traced graph is a low-level equivalent representation of the original graph module, and is only used for
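For readers unfamiliar with the dispatch-based decomposition this interpreter relies on, here is a minimal standalone sketch using PyTorch's public `make_fx` tracer together with the same `core_aten_decompositions()` table. It is illustrative only and not the exact code path added in this PR:

```python
# Minimal sketch of torch-dispatch decomposition using only public PyTorch
# APIs; not the machinery of this PR itself.
import torch
from torch._decomp import core_aten_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # F.silu is a high-level op; with the decomposition table it should be
    # lowered to core ATen ops (roughly sigmoid + mul) in the traced graph.
    return torch.nn.functional.silu(x)

gm = make_fx(f, decomposition_table=core_aten_decompositions())(torch.randn(4))
print(gm.graph)  # shows the decomposed, low-level representation
```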
@@ -106,7 +106,6 @@ def placeholder(self, target, args, kwargs):
track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)

out = pytree.tree_map_only(torch.Tensor, lambda x: to_fun(x), out)
- # TODO handle case where the first character of target is '*'
return out

def call_function(self, target, args, kwargs):
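As a side note on the `pytree.tree_map_only` call above: it applies a function only to leaves of a given type while preserving the container structure, which is how the interpreter wraps every tensor in an output tree into its functional counterpart. A small standalone illustration with the real `torch.utils._pytree` helper, using a toy doubling function in place of `to_fun`:

```python
import torch
import torch.utils._pytree as pytree

nested = {"a": torch.ones(2), "b": [3, torch.zeros(2)]}
# Only torch.Tensor leaves are transformed; the int and the container
# structure pass through untouched.
out = pytree.tree_map_only(torch.Tensor, lambda t: t * 2, nested)
# out == {"a": tensor([2., 2.]), "b": [3, tensor([0., 0.])]}
```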
@@ -187,9 +186,25 @@ def run(self, *args, **kwargs):
def decompose_and_functionalize(
graph_module: GraphModule,
- decomposition_table: Dict = core_aten_decompositions(),
+ decomposition_table: Dict[torch._ops.OperatorBase, Callable] = core_aten_decompositions(),
leaf_function_targets: List[Callable] = [F.scaled_dot_product_attention],
) -> Callable:
+ """
+ API to decompose and functionalize a high-level graph module.
+
+ Args:
+ graph_module (GraphModule):
+ The high-level graph module to be decomposed and functionalized.
+ decomposition_table (Dict[torch._ops.OperatorBase, Callable], defaults to `core_aten_decompositions()`):
+ The lookup table which maps high-level torch ops to their equivalent low-level implementations.
+ leaf_function_targets (List[Callable], defaults to `[F.scaled_dot_product_attention]`):
+ Functions which, for convenience, will not be traced through. `F.scaled_dot_product_attention` is
+ treated as a leaf function by default so that we don't have to deal with all the detailed variants of
+ SDPA in the traced graph.
+
+ Returns:
+ Callable: A wrapper which returns the traced low-level graph when called with concrete arguments.
+ """
new_graph = Graph(owning_module=graph_module)
interp = DecompositionInterpreter(graph_module, new_graph, decomposition_table, leaf_function_targets)
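To make the intended call pattern concrete, here is a hypothetical usage sketch based on the docstring above; the example model, the use of `symbolic_trace` to obtain the high-level graph module, and having `decompose_and_functionalize` in scope are assumptions, not verified against this PR:

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace

# Hypothetical example model; symbolic_trace yields the high-level graph module.
model = nn.Sequential(nn.Linear(8, 8), nn.GELU())
gm = symbolic_trace(model)

# Per the docstring, decompose_and_functionalize returns a wrapper; calling it
# with concrete inputs runs the DecompositionInterpreter and yields the traced
# low-level graph. (decompose_and_functionalize is assumed imported from this PR.)
wrapper = decompose_and_functionalize(gm)
low_level_graph = wrapper(torch.randn(2, 8))
```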