Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatic Model Parallelism Through FX #1933

Merged
Changes from 10 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
5e39787
WIP
zhenglongjiepheonix Jun 3, 2024
7a5d394
add dist ops
zhenglongjiepheonix Jun 11, 2024
7e15d26
Merge remote-tracking branch 'upstream/main' into longjie/add_automat…
zhenglongjiepheonix Jun 11, 2024
98e5846
add index propagation
zhenglongjiepheonix Jun 15, 2024
2036dbb
support tp for linears
zhenglongjiepheonix Jul 1, 2024
34fffe8
Merge remote-tracking branch 'upstream/main' into longjie/add_automat…
zhenglongjiepheonix Jul 1, 2024
0876f5d
add embedding & weight tie
zhenglongjiepheonix Jul 8, 2024
87e66fb
Merge remote-tracking branch 'upstream/main' into longjie/add_automat…
zhenglongjiepheonix Jul 8, 2024
ae6d9d2
address comments
zhenglongjiepheonix Jul 8, 2024
455c0c7
lint
zhenglongjiepheonix Jul 8, 2024
27a9bb8
fix
zhenglongjiepheonix Jul 12, 2024
473388b
Merge remote-tracking branch 'upstream/main' into longjie/add_automat…
zhenglongjiepheonix Jul 12, 2024
0512b23
fix
zhenglongjiepheonix Jul 12, 2024
8ec6727
debug
zhenglongjiepheonix Jul 13, 2024
5095f1e
fix
zhenglongjiepheonix Jul 13, 2024
f6ebfc0
fix tests
zhenglongjiepheonix Jul 15, 2024
e71e5ea
add experimental API
zhenglongjiepheonix Jul 16, 2024
eb2a7a6
Merge remote-tracking branch 'upstream/main' into longjie/add_automat…
zhenglongjiepheonix Jul 16, 2024
779c77d
nit
zhenglongjiepheonix Jul 16, 2024
e09df2a
fix api
zhenglongjiepheonix Jul 17, 2024
22fe1a3
Merge remote-tracking branch 'upstream/main' into longjie/add_automat…
zhenglongjiepheonix Jul 17, 2024
9fd29d1
fix api
zhenglongjiepheonix Jul 18, 2024
01cfc25
format
zhenglongjiepheonix Jul 18, 2024
8c16267
clean tests
zhenglongjiepheonix Jul 18, 2024
8ef00e0
fix weight_map
zhenglongjiepheonix Jul 18, 2024
6ef2081
add weights loading
zhenglongjiepheonix Jul 22, 2024
2c561d3
format
zhenglongjiepheonix Jul 22, 2024
fc96b6f
fix
zhenglongjiepheonix Jul 22, 2024
8d2cabb
fix
zhenglongjiepheonix Jul 23, 2024
c9c7571
Merge remote-tracking branch 'upstream/main' into longjie/add_automat…
zhenglongjiepheonix Jul 23, 2024
97e6431
enable tests
zhenglongjiepheonix Jul 23, 2024
efd5d28
address comments
zhenglongjiepheonix Jul 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions optimum/fx/parallelization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

import torch
from torch.fx import GraphModule

from .core import Config, ParallelExecutionCtx
from .passes import build_parallel_pass_pipeline


def parallelize_backend(
graph_module: GraphModule, example_inputs: List[torch.Tensor], ctx: ParallelExecutionCtx, config: Config
) -> GraphModule:
ctx.example_inputs = example_inputs
pass_pipeline = build_parallel_pass_pipeline()
graph_module = pass_pipeline(graph_module=graph_module, ctx=ctx, config=config)
ctx.compile_times += 1
return graph_module
154 changes: 154 additions & 0 deletions optimum/fx/parallelization/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.distributed as dist
import torch.nn as nn


class HashableSlice:
def __init__(self, start: Optional[int] = None, stop: Optional[int] = None, step: Optional[int] = None) -> None:
self.start = start
self.stop = stop
self.step = step

def __hash__(self) -> int:
return hash(f"{self.start},{self.stop},{self.step}")

def __eq__(self, value: object) -> bool:
return (
isinstance(value, HashableSlice)
and self.start == value.start
and self.stop == value.stop
and self.step == value.step
)

def to_slice(self) -> slice:
return slice(self.start, self.stop, self.step)


@dataclass
class ParameterSlice:
"""
A slice of parameter which corresponds to a tensor in weight dict. Only support slicing
along a specific axis (the potential parallel axis) right now.

Attributes:
- source (`Optional[str]`):
Original parameter name which can be found in the weight dict.

- index (`Optional[slice]`):
Index to slice the tensor on the parallel axis. Assume tensor in weight dict has the same
layout as their correspondings in memory.
"""

source: Optional[str] = None
index: Optional[slice] = None


@dataclass
class ParameterMeta:
"""
Parameter meta information.

Attributes:
- is_tied (`bool`, defaults to `False`):
Whether the parameter is shared accross multiple modules.

- is_modified_meta (`bool`, defaults to `False`):
Whether the meta has already been modified since initialization.

- need_initialize (`bool`, defaults to `False`):
Whether need to manually initialize weights if not provided in weight map.

- init_fn (`Optional[Callable]`):
Initialization function, can override `weight_init_fn` in `Config` if not None.

- dim (`int`, defaults to `0`):
Axis on which `mapping` is based.

- mapping (`Dict[HashableSlice, ParameterSlice]`):
Mapping between the current parameter and weight tensor stored in weight map.
"""

is_tied: bool = False
is_modified_meta: bool = False
need_initialize: bool = False
init_fn: Optional[Callable] = None
dim: int = 0
mapping: Dict[HashableSlice, ParameterSlice] = field(default_factory=dict)


@dataclass
class ParallelExecutionCtx:
"""
Parallel execution context which contains runtime information.

Attributes:
- tp_group (`dist.ProcessGroup`):
Tensor parallel process group the current process belongs to.

- current_device (`torch.device`):
Device correpsonding to the current process.

- example_inputs (`List[Any]`):
A list of tensors which are used as example inputs for graphs captured by dynamo.

- parallel_layer_cache (`Dict[int, nn.Module]`):
Cache which maps layers(`nn.Linear`, `nn.Embedding`) to their parallel counterparts.
Note that we will build the cache in the first compilation process, and for recompilations
later on, we will directly replace the modules with their parallel counterparts in the cache,
because we have to make sure we don't initiate new parameters and replace original ones when
recompilation happens in training process.

- weight_map (`Dict[str, str]`):
Mapping between parameter names and their locations on disk, useful when loading weights
from disk.

- compile_times (`int`, defaults to `0`):
Number of compilation times happened during the whole process.
"""

tp_group: dist.ProcessGroup
current_device: torch.device
example_inputs: List[Any] = field(default_factory=list)
parallel_layer_cache: Dict[int, nn.Module] = field(default_factory=dict)
weight_map: Dict[str, str] = field(default_factory=dict)
compile_times: int = 0


@dataclass
class Config:
"""
Static config which contains instructions which do not change in runtime.

Attributes:
- lint_and_recompile (`bool`, defaults to `True`):
Whether to run graph linting and module recompilation after every pass.

- clean_markers_after_all_passes (`bool`, defaults to `True`):
Whether to clean markers of analytical passes after all passes have run.

- weight_init_fn (`Callable`, defaults to `partial(nn.init.normal_, std=0.02)`)
Initialization function of weights in `nn.Linear` and `nn.Embedding` layers,
if not provided weights loading path.
"""

lint_and_recompile: bool = True
clean_markers_after_all_passes: bool = True
weight_init_fn: Callable = partial(nn.init.normal_, std=0.02)
21 changes: 21 additions & 0 deletions optimum/fx/parallelization/distributed/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .dist_ops import (
differentiable_all_gather,
differentiable_all_reduce_sum,
differentiable_identity,
differentiable_scatter,
scatter,
)
147 changes: 147 additions & 0 deletions optimum/fx/parallelization/distributed/dist_ops.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file seems more related to the parallel layers. Hopefully at some point we could use existing backends instead.
Like nanotron or megatron etc.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would be great ! maybe even the torch native parallelism layers.

Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed as dist

from ..utils import ensure_divisibility


def all_reduce(group: dist.ProcessGroup, tensor: torch.Tensor) -> torch.Tensor:
world_size = dist.get_world_size(group)
if world_size == 1:
return tensor

dist.all_reduce(tensor, group=group)
return tensor


def all_gather(group: dist.ProcessGroup, tensor: torch.Tensor, gather_dim: int = -1) -> torch.Tensor:
world_size = dist.get_world_size(group)
if world_size == 1:
return tensor
rank = dist.get_rank(group=group)

tensor = tensor.contiguous()
gather_dim = (gather_dim + tensor.ndim) % tensor.ndim
shape = tuple(
tensor.size(dim) * world_size if dim == gather_dim else tensor.size(dim) for dim in range(tensor.ndim)
)
index = [
slice(rank * tensor.size(dim), (rank + 1) * tensor.size(dim), None)
if dim == gather_dim
else slice(None, None, None)
for dim in range(tensor.ndim)
]
tensors = torch.empty(*shape, dtype=tensor.dtype, device=tensor.device)
tensors[index] = tensor
dist.all_gather_into_tensor(tensors, tensor, group=group)
return tensors


def split(group: dist.ProcessGroup, tensor: torch.Tensor, split_dim: int = -1) -> torch.Tensor:
world_size = dist.get_world_size(group)
if world_size == 1:
return tensor

rank = dist.get_rank(group)
size = tensor.size()
ensure_divisibility(size[split_dim], world_size)
tensors = torch.split(tensor, size[split_dim] // world_size, dim=split_dim)
tensor = tensors[rank].contiguous()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why contiguous?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tensors after split may not be contiguous, I think it's better be contiguous


return tensor


def scatter(
group: dist.ProcessGroup, tensor: torch.Tensor, output_tensor: torch.Tensor, scatter_dim: int = 0
) -> torch.Tensor:
world_size = dist.get_world_size(group)
if world_size == 1:
return tensor

rank = dist.get_rank(group)
if rank == 0:
size = tensor.size()
ensure_divisibility(size[scatter_dim], world_size)
tensors = torch.split(tensor, size[scatter_dim] // world_size, dim=scatter_dim)
scatter_list = [tensor.contiguous() for tensor in tensors]
output_tensor.copy_(scatter_list[rank])
else:
scatter_list = None
dist.scatter(tensor=output_tensor, scatter_list=scatter_list, src=0, group=group)
return output_tensor


class DifferentiableIdentity(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor, group: dist.ProcessGroup):
ctx.group = group
return tensor

@staticmethod
def backward(ctx, grad_output):
group = ctx.group
return DifferentiableAllReduceSum.apply(grad_output, group), None


class DifferentiableAllReduceSum(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
ctx.group = group
return all_reduce(group=group, tensor=tensor)

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> torch.Any:
return grad_output, None


class DifferentiableScatter(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor: torch.Tensor, group: dist.ProcessGroup, dim: int = -1) -> torch.Tensor:
ctx.group = group
ctx.dim = dim
return split(group=group, tensor=tensor, split_dim=dim)

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
return DifferentiableAllGather.apply(grad_output, group=ctx.group, dim=ctx.dim), None, None


class DifferentiableAllGather(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor: torch.Tensor, group: dist.ProcessGroup, dim: int = -1) -> torch.Tensor:
ctx.group = group
ctx.dim = dim
return all_gather(group=group, tensor=tensor, gather_dim=dim)

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
return DifferentiableScatter.apply(grad_output, group=ctx.group, dim=ctx.dim), None, None


def differentiable_all_reduce_sum(tensor: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
return DifferentiableAllReduceSum.apply(tensor, group)


def differentiable_identity(tensor: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
return DifferentiableIdentity.apply(tensor, group)


def differentiable_all_gather(tensor: torch.Tensor, group: dist.ProcessGroup, dim=-1) -> torch.Tensor:
return DifferentiableAllGather.apply(tensor, group, dim)


def differentiable_scatter(tensor: torch.Tensor, group: dist.ProcessGroup, dim=-1) -> torch.Tensor:
return DifferentiableScatter.apply(tensor, group, dim)
16 changes: 16 additions & 0 deletions optimum/fx/parallelization/parallel_layers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .embedding import VocabParallelEmbedding
from .linear import ColumnParallelLinear, RowParallelLinear
Loading