Commit bf1befd

Add Parallel Cross Entropy (#2017)
1 parent 2179d33 commit bf1befd

7 files changed: +280 -20 lines changed

optimum/fx/parallelization/decomp.py

+1 -1
@@ -197,7 +197,7 @@ def run(self, *args, **kwargs):
 def decompose_and_functionalize(
     graph_module: GraphModule,
     decomposition_table: Dict[torch._ops.OperatorBase, Callable] = core_aten_decompositions(),
-    leaf_function_targets: List[Callable] = [F.scaled_dot_product_attention],
+    leaf_function_targets: List[Callable] = [F.scaled_dot_product_attention, F.cross_entropy],
 ) -> Callable:
     """
     API to decompose and functionalize a high-level graph module.
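For context (not part of the diff): adding F.cross_entropy to leaf_function_targets keeps the loss as a single call_function node through decomposition, so the later parallelization passes can match and replace it as one unit. A toy torch.fx sketch of what such a node looks like (ToyLM is purely illustrative):

# Illustrative only, not from this commit.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.fx import symbolic_trace


class ToyLM(nn.Module):
    def __init__(self, hidden: int = 16, vocab: int = 32):
        super().__init__()
        self.lm_head = nn.Linear(hidden, vocab)

    def forward(self, hidden_states: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        logits = self.lm_head(hidden_states)  # (N, vocab)
        return F.cross_entropy(logits, labels)


gm = symbolic_trace(ToyLM())
# The loss appears as one node targeting F.cross_entropy rather than its aten decomposition.
assert any(n.op == "call_function" and n.target is F.cross_entropy for n in gm.graph.nodes)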

optimum/fx/parallelization/op_registry/op_handlers.py

+27 -8
@@ -19,7 +19,7 @@
 from torch.fx import Node

 from ..core import Config
-from ..utils import is_activation, is_embedding, is_linear
+from ..utils import is_activation, is_cross_entropy, is_cross_entropy_parallel_compatible, is_embedding, is_linear


 class Registry:
@@ -334,7 +334,16 @@ def propagate(self) -> List[int]:
         ndim = arg.meta["val"].ndim
         slice_dim = (slice_dim + ndim) % ndim
         if slice_dim == axis:
-            # slice on the parallel axis is not allowed
+            # slice on the parallel axis is not allowed, unless it is a no-op
+            start, stop, step = 0, arg.meta["val"].shape[axis], 1
+            if len(self.node.args) > 2:
+                start = self.node.args[2]
+            if len(self.node.args) > 3:
+                stop = self.node.args[3]
+            if len(self.node.args) > 4:
+                step = self.node.args[4]
+            if start == 0 and stop >= arg.meta["val"].shape[axis] and step == 1:
+                return [axis]
             return []
         return [axis]
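A standalone sketch of the no-op condition introduced above (the helper name is hypothetical, not part of the commit): a slice that selects the full range of the parallel axis keeps the axis, anything else rejects it.

from typing import Optional


# Hypothetical helper mirroring the check above: slicing the parallel axis is only
# acceptable when the slice covers the whole axis with step 1, i.e. it is a no-op.
def slice_is_noop(dim_size: int, start: int = 0, stop: Optional[int] = None, step: int = 1) -> bool:
    stop = dim_size if stop is None else stop
    return start == 0 and stop >= dim_size and step == 1


assert slice_is_noop(128)                  # full-range slice: the parallel axis can be kept
assert not slice_is_noop(128, 0, 64)       # genuine slice on the parallel axis: rejected
assert not slice_is_noop(128, 0, 128, 2)   # strided slice: rejected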

@@ -404,12 +413,12 @@ def propagate(self) -> List[int]:
         if self.node.op in ["placeholder", "get_attr"]:
             return [None]
         elif self.node.op == "output":
-            for node in self.node.all_input_nodes:
-                # TODO: allow parallelized nodes in output, and append comm ops in graph tp all-gather
-                # parallelized output if intructed
-                if self.extract_axis(node) is not None:
-                    return []
-            return [None]
+            # We don't check whether the output is parallelized here: if the output is the loss,
+            # it is guaranteed not to be parallelized as long as it comes from sharded cross entropy.
+            # TODO: append all-gather comm ops before all parallelized output nodes if instructed.
+            input_arg = self.node.all_input_nodes[0]
+            axis = self.extract_axis(input_arg)
+            return [axis]
         elif is_linear(self.node):
             input_arg = self.node.all_input_nodes[0]
             axis = self.extract_axis(input_arg)
@@ -438,6 +447,16 @@ def propagate(self) -> List[int]:
                 return [1, None] if self.config.enable_sequence_parallel else [None]
             else:
                 return []
+        elif is_cross_entropy(self.node):
+            logits = self.node.all_input_nodes[0]
+            axis = self.extract_axis(logits)
+            if axis is None or (
+                is_cross_entropy_parallel_compatible(self.node) and axis == logits.meta["val"].ndim - 1
+            ):
+                # for cross entropy, the parallel axis of the input logits can only be the last axis or None
+                return [None]
+            else:
+                return []
         elif is_activation(self.node):
             return UnaryOpParallelAxisPropagateHandler(self.node, self.meta_key, self.config).propagate()

optimum/fx/parallelization/parallel_layers/__init__.py

+1
@@ -14,3 +14,4 @@
 # limitations under the License.
 from .embedding import VocabParallelEmbedding
 from .linear import ColumnParallelLinear, RowParallelLinear
+from .loss import VocabParallelCrossEntropyLoss, sharded_cross_entropy_wrapper_fn
optimum/fx/parallelization/parallel_layers/loss.py

+163 -0
@@ -0,0 +1,163 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import wraps
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+
+from ..core import ParallelExecutionCtx
+
+
+# Adapted from https://github.com/huggingface/nanotron/blob/main/src/nanotron/parallel/tensor_parallel/functional.py
+class _ShardedCrossEntropy(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        sharded_logits: torch.Tensor,  # (batch_size, length, sharded_hidden_size)
+        target: torch.Tensor,  # (batch_size, length)
+        group: dist.ProcessGroup,
+    ):
+        # Maximum value along last dimension across all GPUs.
+        logits_max = torch.max(sharded_logits, dim=-1)[0]
+        dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=group)
+        # Subtract the maximum value.
+        sharded_logits = sharded_logits - logits_max.unsqueeze(dim=-1)
+
+        # Get the shard's indices
+        sharded_hidden_size = sharded_logits.shape[-1]
+        rank = dist.get_rank(group)
+        start_index = rank * sharded_hidden_size
+        end_index = start_index + sharded_hidden_size
+
+        # Create a mask of valid ids (1 means it needs to be masked).
+        target_mask = (target < start_index) | (target >= end_index)
+        masked_target = target.clone() - start_index
+        masked_target[target_mask] = 0
+
+        # Get predicted-logits = logits[target].
+        # For simplicity, we convert logits to a 2-D tensor with size
+        # [*, shard-size] and target to a 1-D tensor of size [*].
+        logits_2d = sharded_logits.view(-1, sharded_hidden_size)
+        masked_target_1d = masked_target.view(-1)
+        arange_1d = torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device)
+        predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
+        if predicted_logits_1d.is_contiguous():
+            predicted_logits_1d = predicted_logits_1d.clone()
+        else:
+            predicted_logits_1d = predicted_logits_1d.contiguous()
+        predicted_logits = predicted_logits_1d.view_as(target)
+        predicted_logits[target_mask] = 0.0
+        # All reduce is needed to get the chunks from other GPUs.
+        dist.all_reduce(predicted_logits, op=dist.ReduceOp.SUM, group=group)
+
+        # Sum of exponential of logits along vocab dimension across all GPUs.
+        exp_logits = sharded_logits
+        torch.exp(sharded_logits, out=exp_logits)
+        sum_exp_logits = exp_logits.sum(dim=-1)
+        dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=group)
+
+        # Loss = log(sum(exp(logits))) - predicted-logit.
+        loss = torch.log(sum_exp_logits) - predicted_logits
+
+        # Normalize and optionally smooth logits
+        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
+
+        # Store softmax, target-mask and masked-target for backward pass.
+        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
+
+        return loss
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):
+        # Retrieve tensors from the forward path.
+        softmax, target_mask, masked_target_1d = ctx.saved_tensors
+
+        # All the inputs have softmax as their gradient.
+        grad_input = softmax
+        # For simplicity, work with the 2D gradient.
+        sharded_hidden_size = softmax.size()[-1]
+        grad_2d = grad_input.view(-1, sharded_hidden_size)
+
+        # Add the gradient from matching classes.
+        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
+        grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float()
+
+        # Finally elementwise multiplication with the output gradients.
+        grad_input.mul_(grad_output.unsqueeze(dim=-1))
+
+        return grad_input, None, None
+
+
+def sharded_cross_entropy(sharded_logits: torch.Tensor, target: torch.Tensor, process_group: dist.ProcessGroup):
+    return _ShardedCrossEntropy.apply(sharded_logits, target, process_group)
+
+
+def sharded_cross_entropy_wrapper_fn(process_group: dist.ProcessGroup):
+    @wraps(sharded_cross_entropy)
+    def wrapper(
+        sharded_logits: torch.Tensor,
+        target: torch.Tensor,
+        weight: Optional[torch.Tensor] = None,
+        size_average: Optional[bool] = None,
+        ignore_index: int = -100,
+        reduce: Optional[bool] = None,
+        reduction: str = "mean",
+        label_smoothing: float = 0.0,
+    ):
+        if weight is not None or ignore_index != -100 or label_smoothing != 0.0:
+            raise ValueError(
+                "The current parallel cross entropy implementation does not support weighted mode, ignore_index or label smoothing."
+            )
+        loss: torch.Tensor = sharded_cross_entropy(sharded_logits, target, process_group)
+
+        if size_average is not None or reduce is not None:
+            size_average = True if size_average is None else size_average
+            reduce = True if reduce is None else reduce
+
+            if size_average and reduce:
+                reduction = "mean"
+            elif reduce:
+                reduction = "sum"
+            else:
+                reduction = "none"
+
+        if reduction == "mean":
+            return loss.mean()
+        elif reduction == "sum":
+            return loss.sum()
+        return loss
+
+    return wrapper
+
+
+class VocabParallelCrossEntropyLoss(nn.Module):
+    """
+    Simple parallel cross entropy implementation which does not support weighted mode and label smoothing yet.
+    """
+
+    def __init__(self, ctx: ParallelExecutionCtx, reduction: str = "mean") -> None:
+        super(VocabParallelCrossEntropyLoss, self).__init__()
+        self.process_group = ctx.tp_group
+        self.reduction = reduction
+
+    def forward(self, sharded_logits: torch.Tensor, target: torch.Tensor):
+        loss: torch.Tensor = _ShardedCrossEntropy.apply(sharded_logits, target, self.process_group)
+        if self.reduction == "mean":
+            return loss.mean()
+        elif self.reduction == "sum":
+            return loss.sum()
+        return loss
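For reference (not part of the diff), the per-token quantity that _ShardedCrossEntropy assembles from the shard-local pieces and three all-reduces (MAX for the global max, SUM for the target logit, SUM for the partition function) is the standard cross entropy. With full-vocabulary logits $z \in \mathbb{R}^V$, target class $y$ and $m = \max_j z_j$:

\mathcal{L}(z, y) = \log \sum_{j=1}^{V} e^{z_j - m} - (z_y - m) = -\log \operatorname{softmax}(z)_y

and backward applies the matching gradient $\operatorname{softmax}(z) - \mathbf{1}_y$ on the local shard. Below is a single-process sanity-check sketch (illustrative, not part of the commit); it assumes a CPU gloo backend and uses the exported sharded_cross_entropy_wrapper_fn, and with world_size = 1 the shard is the whole vocabulary, so the result should match dense cross entropy:

# Single-process sanity check: sharded loss vs. dense reference.
import os

import torch
import torch.distributed as dist
import torch.nn as nn

from optimum.fx.parallelization.parallel_layers import sharded_cross_entropy_wrapper_fn

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

batch, length, vocab = 2, 5, 8
logits = torch.randn(batch, length, vocab)
target = torch.randint(0, vocab, (batch, length))

loss_fn = sharded_cross_entropy_wrapper_fn(process_group=dist.group.WORLD)
loss = loss_fn(logits, target)  # default reduction="mean"
ref = nn.CrossEntropyLoss()(logits.reshape(-1, vocab), target.reshape(-1))
assert torch.allclose(loss, ref, atol=1e-6)

dist.destroy_process_group()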

optimum/fx/parallelization/passes.py

+44 -1
@@ -26,8 +26,15 @@
 from .decomp import decompose_and_functionalize
 from .distributed import scatter
 from .op_registry import REGISTRY, FallbackParallelAxisPropagateHandler
-from .parallel_layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
+from .parallel_layers import (
+    ColumnParallelLinear,
+    RowParallelLinear,
+    VocabParallelCrossEntropyLoss,
+    VocabParallelEmbedding,
+    sharded_cross_entropy_wrapper_fn,
+)
 from .utils import (
+    is_cross_entropy,
     is_embedding,
     is_linear,
     is_shape_consumer,
@@ -273,6 +280,11 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf
                 info["sequence_parallel"] = False
                 self.place_marker_per_node(node, info)

+            elif is_cross_entropy(node):
+                axis_before = ParallelAxisSolverPass.get_stored_field_info(node.args[0], "parallel_axis")
+                if axis_before is not None:
+                    self.place_marker_per_node(node, {"axis": "vocab"})
+
         return graph_module

@@ -343,6 +355,35 @@ def handle_embedding(node: Node, ctx: ParallelExecutionCtx) -> None:
             layer_cache[key] = new_mod
         setattr(parent_mod, field, new_mod)

+    @staticmethod
+    def handle_cross_entropy(node: Node, ctx: ParallelExecutionCtx) -> None:
+        axis = ParallelLayerAnnotatePass.get_stored_field_info(node, field="axis")
+        if axis is None:
+            return
+
+        assert axis in {"vocab"}, "Only support parallelization on vocab dim for now."
+        if node.op == "call_module":
+            graph_module = node.graph.owning_module
+            prefix_and_field = node.target.rsplit(".", maxsplit=1)
+            if len(prefix_and_field) == 2:
+                parent_mod = graph_module.get_submodule(prefix_and_field[0])
+                field = prefix_and_field[1]
+            else:
+                parent_mod = graph_module
+                field = node.target
+
+            mod: nn.CrossEntropyLoss = graph_module.get_submodule(node.target)
+            key, layer_cache = node.target, ctx.parallel_layer_cache
+            if key in layer_cache:
+                new_mod = layer_cache[key]
+            else:
+                assert ctx.compile_times == 0, "illegal path for recompilation"
+                new_mod = VocabParallelCrossEntropyLoss(ctx, reduction=mod.reduction)
+                layer_cache[key] = new_mod
+            setattr(parent_mod, field, new_mod)
+        else:
+            node.target = sharded_cross_entropy_wrapper_fn(process_group=ctx.tp_group)
+
     @staticmethod
     def handle_hard_coded_axis_param(node: Node, ctx: ParallelExecutionCtx) -> None:
         def extract_shape_from_node(node: Node) -> List[Any]:
@@ -384,6 +425,8 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf
                 self.handle_linear(node, ctx)
             elif is_embedding(node):
                 self.handle_embedding(node, ctx)
+            elif is_cross_entropy(node):
+                self.handle_cross_entropy(node, ctx)
             # correct the attention head num in parallel setting
             elif is_shape_consumer(node):
                 self.handle_hard_coded_axis_param(node, ctx)
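To make the call_module branch of handle_cross_entropy concrete, here is a toy sketch of the parent/field resolution and submodule swap (the Wrapper/loss_fn names are hypothetical, and a plain nn.CrossEntropyLoss stands in for VocabParallelCrossEntropyLoss, which needs a live ParallelExecutionCtx):

# Toy illustration of the submodule swap performed by the pass.
import torch.nn as nn
from torch.fx import symbolic_trace


class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        return self.loss_fn(logits, labels)


gm = symbolic_trace(Wrapper())
node = next(n for n in gm.graph.nodes if n.op == "call_module")  # target == "loss_fn"

prefix_and_field = node.target.rsplit(".", maxsplit=1)
if len(prefix_and_field) == 2:
    parent_mod, field = gm.get_submodule(prefix_and_field[0]), prefix_and_field[1]
else:
    parent_mod, field = gm, node.target

# In the real pass this would be VocabParallelCrossEntropyLoss(ctx, reduction=mod.reduction);
# here another module is swapped in just to show the mechanics.
setattr(parent_mod, field, nn.CrossEntropyLoss(reduction="sum"))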

optimum/fx/parallelization/utils.py

+34
@@ -82,6 +82,40 @@ def is_shape_generator(node: Node) -> bool:
     return node.op == "call_method" and node.target == "size"


+def is_cross_entropy(node: Node) -> bool:
+    if node.op == "call_function":
+        return node.target is F.cross_entropy
+    elif node.op == "call_module":
+        mod = node.graph.owning_module
+        return isinstance(mod.get_submodule(node.target), nn.CrossEntropyLoss)
+    return False
+
+
+def is_cross_entropy_parallel_compatible(node: Node) -> bool:
+    """
+    For now `VocabParallelCrossEntropyLoss` does not support weighted mode, `ignore_index` or label smoothing.
+    """
+    if node.op == "call_function":
+        weight = node.kwargs.get("weight", None)
+        ignore_index = node.kwargs.get("ignore_index", -100)
+        label_smoothing = node.kwargs.get("label_smoothing", 0.0)
+        if len(node.args) > 2 and weight is None:
+            weight = node.args[2]
+        if len(node.args) > 4 and ignore_index == -100:
+            ignore_index = node.args[4]
+        if len(node.args) > 7 and label_smoothing == 0.0:
+            label_smoothing = node.args[7]
+
+        return weight is None and ignore_index == -100 and label_smoothing == 0.0
+
+    elif node.op == "call_module":
+        mod: nn.CrossEntropyLoss = node.graph.owning_module.get_submodule(node.target)
+        weight, label_smoothing, ignore_index = mod.weight, mod.label_smoothing, mod.ignore_index
+        return weight is None and ignore_index == -100 and label_smoothing == 0.0
+
+    return False
+
+
 def stable_topological_sort(graph: Graph):
     def _args(n: torch.fx.Node) -> List[torch.fx.node.Argument]:
         args: List[torch.fx.node.Argument] = []
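The following toy calls illustrate which call sites is_cross_entropy_parallel_compatible is meant to accept or reject once they appear as graph nodes; this is an illustration of the rule, not code from the commit:

import torch
import torch.nn.functional as F

logits, labels = torch.randn(4, 10), torch.randint(0, 10, (4,))

F.cross_entropy(logits, labels)                         # default weight/ignore_index/label_smoothing: compatible
F.cross_entropy(logits, labels, label_smoothing=0.1)    # label smoothing: would be rejected
F.cross_entropy(logits, labels, ignore_index=0)         # custom ignore_index: would be rejected
F.cross_entropy(logits, labels, weight=torch.ones(10))  # class weights: would be rejected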
