|
| 1 | +# coding=utf-8 |
| 2 | +# Copyright 2024 The HuggingFace Team. All rights reserved. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +from functools import wraps |
| 16 | +from typing import Optional |
| 17 | + |
| 18 | +import torch |
| 19 | +import torch.distributed as dist |
| 20 | +import torch.nn as nn |
| 21 | + |
| 22 | +from ..core import ParallelExecutionCtx |
| 23 | + |
| 24 | + |
| 25 | +# Adapted from https://github.com/huggingface/nanotron/blob/main/src/nanotron/parallel/tensor_parallel/functional.py |
| 26 | +class _ShardedCrossEntropy(torch.autograd.Function): |
| 27 | + @staticmethod |
| 28 | + def forward( |
| 29 | + ctx, |
| 30 | + sharded_logits: torch.Tensor, # (batch_size, length, sharded_hidden_size) |
| 31 | + target: torch.Tensor, # (batch_size, length) |
| 32 | + group: dist.ProcessGroup, |
| 33 | + ): |
| 34 | + # Maximum value along last dimension across all GPUs. |
| 35 | + logits_max = torch.max(sharded_logits, dim=-1)[0] |
| 36 | + dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=group) |
| 37 | + # Subtract the maximum value. |
| 38 | + sharded_logits = sharded_logits - logits_max.unsqueeze(dim=-1) |
| 39 | + |
| 40 | + # Get the shard's indices |
| 41 | + sharded_hidden_size = sharded_logits.shape[-1] |
| 42 | + rank = dist.get_rank(group) |
| 43 | + start_index = rank * sharded_hidden_size |
| 44 | + end_index = start_index + sharded_hidden_size |
| 45 | + |
| 46 | + # Create a mask of valid ids (1 means it needs to be masked). |
| 47 | + target_mask = (target < start_index) | (target >= end_index) |
| 48 | + masked_target = target.clone() - start_index |
| 49 | + masked_target[target_mask] = 0 |
| 50 | + |
| 51 | + # Get predicted-logits = logits[target]. |
| 52 | + # For Simplicity, we convert logits to a 2-D tensor with size |
| 53 | + # [*, shard-size] and target to a 1-D tensor of size [*]. |
| 54 | + logits_2d = sharded_logits.view(-1, sharded_hidden_size) |
| 55 | + masked_target_1d = masked_target.view(-1) |
| 56 | + arange_1d = torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device) |
| 57 | + predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] |
| 58 | + if predicted_logits_1d.is_contiguous(): |
| 59 | + predicted_logits_1d = predicted_logits_1d.clone() |
| 60 | + else: |
| 61 | + predicted_logits_1d = predicted_logits_1d.contiguous() |
| 62 | + predicted_logits = predicted_logits_1d.view_as(target) |
| 63 | + predicted_logits[target_mask] = 0.0 |
| 64 | + # All reduce is needed to get the chunks from other GPUs. |
| 65 | + dist.all_reduce(predicted_logits, op=dist.ReduceOp.SUM, group=group) |
| 66 | + |
| 67 | + # Sum of exponential of logits along vocab dimension across all GPUs. |
| 68 | + exp_logits = sharded_logits |
| 69 | + torch.exp(sharded_logits, out=exp_logits) |
| 70 | + sum_exp_logits = exp_logits.sum(dim=-1) |
| 71 | + dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=group) |
| 72 | + |
| 73 | + # Loss = log(sum(exp(logits))) - predicted-logit. |
| 74 | + loss = torch.log(sum_exp_logits) - predicted_logits |
| 75 | + |
| 76 | + # Normalize and optionally smooth logits |
| 77 | + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) |
| 78 | + |
| 79 | + # Store softmax, target-mask and masked-target for backward pass. |
| 80 | + ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) |
| 81 | + |
| 82 | + return loss |
| 83 | + |
| 84 | + @staticmethod |
| 85 | + def backward(ctx, grad_output: torch.Tensor): |
| 86 | + # Retrieve tensors from the forward path. |
| 87 | + softmax, target_mask, masked_target_1d = ctx.saved_tensors |
| 88 | + |
| 89 | + # All the inputs have softmax as their gradient. |
| 90 | + grad_input = softmax |
| 91 | + # For simplicity, work with the 2D gradient. |
| 92 | + sharded_hidden_size = softmax.size()[-1] |
| 93 | + grad_2d = grad_input.view(-1, sharded_hidden_size) |
| 94 | + |
| 95 | + # Add the gradient from matching classes. |
| 96 | + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) |
| 97 | + grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float() |
| 98 | + |
| 99 | + # Finally elementwise multiplication with the output gradients. |
| 100 | + grad_input.mul_(grad_output.unsqueeze(dim=-1)) |
| 101 | + |
| 102 | + return grad_input, None, None |
| 103 | + |
| 104 | + |
| 105 | +def sharded_cross_entropy(sharded_logits: torch.Tensor, target: torch.Tensor, process_group: dist.ProcessGroup): |
| 106 | + return _ShardedCrossEntropy.apply(sharded_logits, target, process_group) |
| 107 | + |
| 108 | + |
| 109 | +def sharded_cross_entropy_wrapper_fn(process_group: dist.ProcessGroup): |
| 110 | + @wraps(sharded_cross_entropy) |
| 111 | + def wrapper( |
| 112 | + sharded_logits: torch.Tensor, |
| 113 | + target: torch.Tensor, |
| 114 | + weight: Optional[torch.Tensor] = None, |
| 115 | + size_average: Optional[bool] = None, |
| 116 | + ignore_index: int = -100, |
| 117 | + reduce: Optional[bool] = None, |
| 118 | + reduction: str = "mean", |
| 119 | + label_smoothing: float = 0.0, |
| 120 | + ): |
| 121 | + if weight is not None or ignore_index != -100 or label_smoothing != 0.0: |
| 122 | + raise ValueError( |
| 123 | + "Does not support weighted mode, index ignoring and label smoothing in current parallel cross entropy implementation." |
| 124 | + ) |
| 125 | + loss: torch.Tensor = sharded_cross_entropy(sharded_logits, target, process_group) |
| 126 | + |
| 127 | + if size_average is not None or reduce is not None: |
| 128 | + size_average = True if size_average is None else size_average |
| 129 | + reduce = True if reduce is None else reduce |
| 130 | + |
| 131 | + if size_average and reduce: |
| 132 | + reduction = "mean" |
| 133 | + elif reduce: |
| 134 | + reduction = "sum" |
| 135 | + else: |
| 136 | + reduction = "none" |
| 137 | + |
| 138 | + if reduction == "mean": |
| 139 | + return loss.mean() |
| 140 | + elif reduction == "sum": |
| 141 | + return loss.sum() |
| 142 | + return loss |
| 143 | + |
| 144 | + return wrapper |
| 145 | + |
| 146 | + |
| 147 | +class VocabParallelCrossEntropyLoss(nn.Module): |
| 148 | + """ |
| 149 | + Simple parallel cross entropy implementation which does not support weighted mode and label smoothing yet. |
| 150 | + """ |
| 151 | + |
| 152 | + def __init__(self, ctx: ParallelExecutionCtx, reduction: str = "mean") -> None: |
| 153 | + super(VocabParallelCrossEntropyLoss, self).__init__() |
| 154 | + self.process_group = ctx.tp_group |
| 155 | + self.reduction = reduction |
| 156 | + |
| 157 | + def forward(self, sharded_logits: torch.Tensor, target: torch.Tensor): |
| 158 | + loss: torch.Tensor = _ShardedCrossEntropy.apply(sharded_logits, target, self.process_group) |
| 159 | + if self.reduction == "mean": |
| 160 | + return loss.mean() |
| 161 | + elif self.reduction == "sum": |
| 162 | + return loss.sum() |
| 163 | + return loss |
0 commit comments