
Commit 252c3b7

more comments
1 parent c689402

3 files changed, +22 -7 lines changed

optimum/fx/parallelization/api.py (+2, -2)

@@ -15,7 +15,7 @@
 import importlib
 import os
 from functools import partial
-from typing import List
+from typing import Callable, List

 import torch
 from torch.fx import GraphModule
@@ -48,7 +48,7 @@ def parallelize_model(
     parallel_ctx: ParallelExecutionCtx,
     *model_args,
     **kwargs,
-):
+) -> Callable:
     """
     API for automatic model parallelism through Pytorch FX.
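For readers skimming the diff, a hypothetical usage sketch of the annotated API follows. Only the names parallelize_model, ParallelExecutionCtx, the *model_args/**kwargs parameters, and the new Callable return annotation come from the hunks above; the import path, the first positional argument, the ParallelExecutionCtx fields, and the example model id are assumptions, not part of this commit.

import torch
import torch.distributed as dist
# Import path assumed from the file layout; not stated in the diff.
from optimum.fx.parallelization import ParallelExecutionCtx, parallelize_model

# Assumed setup: one process per GPU in a tensor-parallel group.
dist.init_process_group("nccl")
device = torch.device("cuda", dist.get_rank())

# Field names below are assumptions; check the library for the real ParallelExecutionCtx signature.
ctx = ParallelExecutionCtx(tp_group=dist.group.WORLD, current_device=device)

# Assumption: the first positional argument identifies the model to parallelize.
# Per this commit, the result is annotated as a Callable, i.e. something you can run directly.
input_ids = torch.randint(0, 1000, (1, 8), device=device)
parallel_model = parallelize_model("gpt2", ctx, input_ids)
outputs = parallel_model(input_ids)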

optimum/fx/parallelization/decomp.py (+18, -3)

@@ -72,7 +72,7 @@ class DecompositionInterpreter(Interpreter):
     DecompositionInterpreter takes the high-level graph module, run the iternal nodes following the topo order, and decompose
     high-level pytorch operators into core aten operators by utilizing torch dispatch infrastructure along the way. Note
     that certain primitive layers(like `nn.Linear`, `nn.Embedding`, and activation layers) are preserved because we have specific
-    heuristic based parallelization strategy for them and we can conveniently replace them into their parallelized counterparts
+    heuristic based parallelization strategy for them so that we can conveniently replace them into their parallelized counterparts
     in the orignal graph module.

     Note that the traced graph is a low-level equivalent representation of the original graph module, and is only used for
@@ -106,7 +106,6 @@ def placeholder(self, target, args, kwargs):
         track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)

         out = pytree.tree_map_only(torch.Tensor, lambda x: to_fun(x), out)
-        # TODO handle case where the first character of target is '*'
         return out

     def call_function(self, target, args, kwargs):
@@ -187,9 +186,25 @@ def run(self, *args, **kwargs):

 def decompose_and_functionalize(
     graph_module: GraphModule,
-    decomposition_table: Dict = core_aten_decompositions(),
+    decomposition_table: Dict[torch._ops.OperatorBase, Callable] = core_aten_decompositions(),
     leaf_function_targets: List[Callable] = [F.scaled_dot_product_attention],
 ) -> Callable:
+    """
+    API to decompose and functionalize a high-level graph module.
+
+    Args:
+        graph_module (GraphModule):
+            The high-level graph module to be decomposed and functionalized.
+        decomposition_table (Dict[torch._ops.OperatorBase, Callable], defaults to `core_aten_decompositions()`):
+            The lookup table which maps high-level torch ops to their equivalent low-level implementations.
+        leaf_function_targets (List[Callable], defaults to `[F.scaled_dot_product_attention]`):
+            Functions which will not be traced through for convenience; `F.scaled_dot_product_attention` is
+            treated as a leaf function by default so that we don't have to deal with all the detailed versions of
+            sdpa in the traced graph.
+
+    Returns:
+        Callable: a wrapper which returns the traced low-level graph when called with concrete arguments.
+    """
     new_graph = Graph(owning_module=graph_module)
     interp = DecompositionInterpreter(graph_module, new_graph, decomposition_table, leaf_function_targets)
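To make the new docstring concrete, here is a hypothetical sketch of how the documented wrapper might be used. The toy module and input are placeholders, and the import path is inferred from the file layout rather than stated in this commit.

import torch
import torch.nn as nn
from torch.fx import symbolic_trace

# Import path inferred from optimum/fx/parallelization/decomp.py; not confirmed by the diff.
from optimum.fx.parallelization.decomp import decompose_and_functionalize

# Placeholder high-level graph module: any symbolically traceable nn.Module would do.
gm = symbolic_trace(nn.Sequential(nn.Linear(8, 8), nn.GELU()))

# Defaults per the signature above: core aten decompositions,
# F.scaled_dot_product_attention kept as a leaf function.
wrapper = decompose_and_functionalize(gm)

# Per the docstring, calling the wrapper with concrete arguments
# returns the traced low-level (decomposed, functionalized) graph.
low_level_graph = wrapper(torch.randn(2, 8))
print(low_level_graph)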

optimum/fx/parallelization/passes.py (+2, -2)

@@ -170,8 +170,8 @@ class ParallelAxisSolverPass(AnalyzeBase):
     - Optimal Solution. Note that since we return the first solution we find, then it might not be optimal in terms of
       memory consumption and communication overhead. But again we can adjust the order of search and try parallelize
       as much as we can first before fall back to non-parallelized search paths. And we don't pay too much attention
-      on calculating communication overhead because in practice they are bounded by number of certain layers in the graph
-      under the constraint that only certain layers are allowed to communicate.
+      on calculating communication overhead because in practice they are bounded under the constraint that only certain
+      layers are allowed to communicate.

     Our goal is not to solve an optimization problem which tries to give a best solution of parallelizing any model under memory/hardware
     constraints, but rather a cheap solution which relieves you from writing boilerplate code for parallelizing layers of different models.
