
Commit 79316ee

format
1 parent f0c8a6b commit 79316ee

4 files changed: 27 additions, 22 deletions

optimum/fx/parallelization/op_registry/op_handlers.py

+4 −2
@@ -19,7 +19,7 @@
 from torch.fx import Node

 from ..core import Config
-from ..utils import is_activation, is_embedding, is_linear, is_cross_entropy, is_cross_entropy_parallel_compatible
+from ..utils import is_activation, is_cross_entropy, is_cross_entropy_parallel_compatible, is_embedding, is_linear


 class Registry:
@@ -450,7 +450,9 @@ def propagate(self) -> List[int]:
         elif is_cross_entropy(self.node):
             logits = self.node.all_input_nodes[0]
             axis = self.extract_axis(logits)
-            if axis is None or (is_cross_entropy_parallel_compatible(self.node) and axis == logits.meta['val'].ndim - 1):
+            if axis is None or (
+                is_cross_entropy_parallel_compatible(self.node) and axis == logits.meta["val"].ndim - 1
+            ):
                 # for cross entropy, the input logits parallel axis can only be the last axis or None
                 return [None]
             else:
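The reformatted condition encodes the rule spelled out in the comment: the logits feeding a cross entropy may either be unsharded (axis is None) or sharded along their last (class/vocab) dimension, and the latter only when the call itself is parallel-compatible. A minimal standalone sketch of that rule, with `axis`, `ndim`, and `parallel_compatible` standing in for the values the handler reads from the node's metadata:

from typing import Optional


def logits_axis_ok_for_cross_entropy(axis: Optional[int], ndim: int, parallel_compatible: bool) -> bool:
    # Unsharded logits are always fine; a sharded parallel axis is only
    # accepted when it is the last (class/vocab) dimension and the cross
    # entropy call itself is parallel-compatible.
    return axis is None or (parallel_compatible and axis == ndim - 1)


assert logits_axis_ok_for_cross_entropy(None, 3, parallel_compatible=True)
assert logits_axis_ok_for_cross_entropy(2, 3, parallel_compatible=True)       # last axis of 3-D logits
assert not logits_axis_ok_for_cross_entropy(1, 3, parallel_compatible=True)   # hidden axis is rejected
assert not logits_axis_ok_for_cross_entropy(2, 3, parallel_compatible=False)  # incompatible call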

optimum/fx/parallelization/parallel_layers/loss.py

+15 −12
@@ -12,11 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import wraps
+from typing import Optional
+
 import torch
-import torch.nn as nn
 import torch.distributed as dist
-from typing import Optional
-from functools import wraps
+import torch.nn as nn
+
 from ..core import ParallelExecutionCtx


@@ -100,7 +102,7 @@ def backward(ctx, grad_output: torch.Tensor):
         return grad_input, None, None


-def sharded_cross_entropy(sharded_logits: torch.Tensor, target: torch.Tensor,process_group: dist.ProcessGroup):
+def sharded_cross_entropy(sharded_logits: torch.Tensor, target: torch.Tensor, process_group: dist.ProcessGroup):
     return _ShardedCrossEntropy.apply(sharded_logits, target, process_group)


@@ -127,15 +129,15 @@ def wrapper(
         reduce = True if reduce is None else reduce

         if size_average and reduce:
-            reduction = 'mean'
+            reduction = "mean"
         elif reduce:
-            reduction = 'sum'
+            reduction = "sum"
         else:
-            reduction = 'none'
+            reduction = "none"

-        if reduction == 'mean':
+        if reduction == "mean":
             return loss.mean()
-        elif reduction == 'sum':
+        elif reduction == "sum":
             return loss.sum()
         return loss

@@ -146,15 +148,16 @@ class VocabParallelCrossEntropyLoss(nn.Module):
     """
     Simple parallel cross entropy implementation which does not support weighted mode and label smoothing yet.
     """
-    def __init__(self, ctx: ParallelExecutionCtx, reduction: str = 'mean') -> None:
+
+    def __init__(self, ctx: ParallelExecutionCtx, reduction: str = "mean") -> None:
         super(VocabParallelCrossEntropyLoss, self).__init__()
         self.process_group = ctx.tp_group
         self.reduction = reduction

     def forward(self, sharded_logits: torch.Tensor, target: torch.Tensor):
         loss: torch.Tensor = _ShardedCrossEntropy.apply(sharded_logits, target, self.process_group)
-        if self.reduction == 'mean':
+        if self.reduction == "mean":
             return loss.mean()
-        elif self.reduction == 'sum':
+        elif self.reduction == "sum":
             return loss.sum()
         return loss
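A hedged usage sketch of the module touched above, assuming torch.distributed is already initialised and that `ctx` is a ParallelExecutionCtx whose `tp_group` spans the ranks holding the vocabulary shards; the wrapper function and shapes are illustrative only:

import torch

from optimum.fx.parallelization.core import ParallelExecutionCtx
from optimum.fx.parallelization.parallel_layers import VocabParallelCrossEntropyLoss


def loss_on_vocab_shards(ctx: ParallelExecutionCtx, sharded_logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # `sharded_logits` holds this rank's slice of the vocabulary dimension,
    # while `target` carries the full (unsharded) token ids; the reduction
    # across ctx.tp_group happens inside the sharded cross entropy itself.
    loss_fn = VocabParallelCrossEntropyLoss(ctx, reduction="mean")
    return loss_fn(sharded_logits, target)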

optimum/fx/parallelization/passes.py

+5 −5
@@ -21,22 +21,23 @@
 import torch.distributed as dist
 import torch.nn as nn
 from torch.fx import Graph, GraphModule, Node
+
 from .core import Config, ParallelExecutionCtx, ParameterMeta
 from .decomp import decompose_and_functionalize
 from .distributed import scatter
 from .op_registry import REGISTRY, FallbackParallelAxisPropagateHandler
 from .parallel_layers import (
     ColumnParallelLinear,
     RowParallelLinear,
-    VocabParallelEmbedding,
     VocabParallelCrossEntropyLoss,
-    sharded_cross_entropy_wrapper_fn
+    VocabParallelEmbedding,
+    sharded_cross_entropy_wrapper_fn,
 )
 from .utils import (
+    is_cross_entropy,
     is_embedding,
     is_linear,
     is_shape_consumer,
-    is_cross_entropy,
     stable_topological_sort,
 )

@@ -282,7 +283,7 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf
             elif is_cross_entropy(node):
                 axis_before = ParallelAxisSolverPass.get_stored_field_info(node.args[0], "parallel_axis")
                 if axis_before is not None:
-                    self.place_marker_per_node(node, {'axis' : 'vocab'})
+                    self.place_marker_per_node(node, {"axis": "vocab"})

         return graph_module

@@ -383,7 +384,6 @@ def handle_cross_entropy(node: Node, ctx: ParallelExecutionCtx) -> None:
         else:
             node.target = sharded_cross_entropy_wrapper_fn(process_group=ctx.tp_group)

-
     @staticmethod
     def handle_hard_coded_axis_param(node: Node, ctx: ParallelExecutionCtx) -> None:
         def extract_shape_from_node(node: Node) -> List[Any]:
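The two functional hunks above belong to a mark-then-rewrite flow: the annotate pass tags a cross-entropy node whose logits carry a parallel axis with {"axis": "vocab"}, and handle_cross_entropy later swaps the node's target for the sharded wrapper. A condensed, illustrative FX sketch of that target swap on a toy module; fake_sharded_cross_entropy is a hypothetical stand-in for the callable returned by sharded_cross_entropy_wrapper_fn:

import torch
import torch.nn.functional as F
from torch.fx import symbolic_trace


class ToyHead(torch.nn.Module):
    def forward(self, logits, target):
        return F.cross_entropy(logits, target)


def fake_sharded_cross_entropy(logits, target, *args, **kwargs):
    # Hypothetical stand-in: the real wrapper computes the loss from
    # vocab-sharded logits and reduces across the tensor-parallel group.
    return F.cross_entropy(logits, target, *args, **kwargs)


gm = symbolic_trace(ToyHead())
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is F.cross_entropy:
        node.target = fake_sharded_cross_entropy  # same kind of rewrite as handle_cross_entropy
gm.recompile()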

optimum/fx/parallelization/utils.py

+3 −3
@@ -96,9 +96,9 @@ def is_cross_entropy_parallel_compatible(node: Node) -> bool:
     For now `VocabParallelCrossEntropyLoss` does not support weighted mode, index ignoring and label smoothing.
     """
     if node.op == "call_function":
-        weight = node.kwargs.get('weight', None)
-        ignore_index = node.kwargs.get('ignore_index', -100)
-        label_smoothing = node.kwargs.get('label_smoothing', 0.0)
+        weight = node.kwargs.get("weight", None)
+        ignore_index = node.kwargs.get("ignore_index", -100)
+        label_smoothing = node.kwargs.get("label_smoothing", 0.0)
         if len(node.args) > 2 and weight is None:
             weight = node.args[2]
         if len(node.args) > 4 and ignore_index == -100:
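The defaults probed here mirror torch.nn.functional.cross_entropy, whose `weight` and `ignore_index` may also be passed positionally (hence the node.args[2] and node.args[4] fallbacks). An illustrative check on two traced call sites, assuming `is_cross_entropy` recognises the traced F.cross_entropy call as the surrounding code suggests:

import torch
import torch.nn.functional as F
from torch.fx import symbolic_trace

from optimum.fx.parallelization.utils import is_cross_entropy, is_cross_entropy_parallel_compatible


class PlainLoss(torch.nn.Module):
    def forward(self, logits, target):
        return F.cross_entropy(logits, target)  # default kwargs only


class SmoothedLoss(torch.nn.Module):
    def forward(self, logits, target):
        return F.cross_entropy(logits, target, label_smoothing=0.1)  # label smoothing unsupported


for module in (PlainLoss(), SmoothedLoss()):
    gm = symbolic_trace(module)
    for node in gm.graph.nodes:
        if is_cross_entropy(node):
            # Expected per the docstring above: True for PlainLoss, False for SmoothedLoss.
            print(type(module).__name__, is_cross_entropy_parallel_compatible(node))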
