Skip to content

Commit c689402

Browse files
more comments
1 parent 4114d3b commit c689402

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

optimum/fx/parallelization/api.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ def parallelize_model(
5454
5555
Args:
5656
model (str):
57-
Model to parallelize, a model id on the Huggingface Hub.
57+
Model to parallelize, a model id on the Huggingface Hub or path to a local directory containing config and weights
58+
of the model.
5859
parallel_ctx (ParallelExecutionCtx):
5960
Parallel execution context containing process groups the current process belongs to.
6061
*model_args (Any):

optimum/fx/parallelization/decomp.py

+11
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,17 @@ def __init__(self, graph: Graph):
6868

6969

7070
class DecompositionInterpreter(Interpreter):
71+
"""
72+
DecompositionInterpreter takes the high-level graph module, run the iternal nodes following the topo order, and decompose
73+
high-level pytorch operators into core aten operators by utilizing torch dispatch infrastructure along the way. Note
74+
that certain primitive layers(like `nn.Linear`, `nn.Embedding`, and activation layers) are preserved because we have specific
75+
heuristic based parallelization strategy for them and we can conveniently replace them into their parallelized counterparts
76+
in the orignal graph module.
77+
78+
Note that the traced graph is a low-level equivalent representation of the original graph module, and is only used for
79+
parallel axis propagation and analysis, the original graph module is still used for real execution.
80+
"""
81+
7182
def __init__(
7283
self, module: GraphModule, new_graph: Graph, decomposition_table=None, leaf_function_targets=None, **kwargs
7384
):

0 commit comments

Comments
 (0)