Skip to content

Commit 40880a3

Browse files
fix
1 parent a375b6d commit 40880a3

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

optimum/fx/parallelization/backend/base.py

+5
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,11 @@ def create_parallel_embedding(
183183

184184
return VocabParallelEmbedding(parallel_ctx, mod)
185185

186+
def create_parallel_cross_entropy(
    self, mod_or_fn: Union[nn.CrossEntropyLoss, F.cross_entropy], parallel_ctx: "ParallelExecutionCtx"
):
    """Create a parallelized cross-entropy for this backend.

    Delegates to the parent backend's default implementation.

    Args:
        mod_or_fn: the cross-entropy being parallelized — either an
            ``nn.CrossEntropyLoss`` module or the ``F.cross_entropy`` function.
        parallel_ctx: the parallel execution context shared across the backend.

    Returns:
        Whatever the parent ``create_parallel_cross_entropy`` returns.
    """
    # NOTE(review): ``F.cross_entropy`` is a function, not a type, so this
    # ``Union`` is not a valid static annotation (type checkers will flag it).
    # Runtime is unaffected because ``typing.Union`` accepts any callable;
    # consider ``Union[nn.CrossEntropyLoss, Callable]`` upstream — TODO confirm.
    # ``"ParallelExecutionCtx"`` is quoted for consistency with the sibling
    # ``post_process`` signature in this file.
    return super().create_parallel_cross_entropy(mod_or_fn, parallel_ctx)
190+
186191
def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> nn.Module:
187192
"""
188193
Initialize or load parameters from checkpoint, and tie them if needed.

optimum/fx/parallelization/backend/nanotron.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515
# Nanotron specific imports
1616
import importlib.util
1717
from collections import defaultdict
18-
from typing import TYPE_CHECKING, Optional, Tuple
18+
from typing import TYPE_CHECKING, Optional, Tuple, Union
1919

2020
import torch.distributed as dist
2121
import torch.nn as nn
22+
import torch.nn.functional as F
2223
from torch.fx import GraphModule
2324

2425
from ..core import Config, ParallelExecutionCtx, ParameterMeta
@@ -149,6 +150,11 @@ def create_parallel_embedding(
149150
contiguous_chunks=contiguous_chunks,
150151
)
151152

153+
def create_parallel_cross_entropy(
    self, mod_or_fn: Union[nn.CrossEntropyLoss, F.cross_entropy], parallel_ctx: "ParallelExecutionCtx"
):
    """Create a parallelized cross-entropy for the nanotron backend.

    Falls back to the parent backend's default implementation rather than
    using a nanotron-specific construct.

    Args:
        mod_or_fn: the cross-entropy being parallelized — either an
            ``nn.CrossEntropyLoss`` module or the ``F.cross_entropy`` function.
        parallel_ctx: the parallel execution context shared across the backend.

    Returns:
        Whatever the parent ``create_parallel_cross_entropy`` returns.
    """
    # NOTE(review): ``F.cross_entropy`` is a function, not a type, so this
    # ``Union`` is not a valid static annotation (type checkers will flag it).
    # Runtime is unaffected because ``typing.Union`` accepts any callable;
    # consider ``Union[nn.CrossEntropyLoss, Callable]`` upstream — TODO confirm.
    return super().create_parallel_cross_entropy(mod_or_fn, parallel_ctx)
157+
152158
def post_process(
153159
self, graph_module: GraphModule, parallel_ctx: "ParallelExecutionCtx", config: "Config"
154160
) -> nn.Module:

0 commit comments

Comments
 (0)