
Commit 5eaf91b

Automatic Model Parallelism Through FX (#1933)
* WIP
* add dist ops
* add index propagation
* support tp for linears
* add embedding & weight tie
* address comments
* lint
* fix
* fix
* debug
* fix
* fix tests
* add experimental API
* nit
* fix api
* fix api
* format
* clean tests
* fix weight_map
* add weights loading
* format
* fix
* fix
* enable tests
* address comments
1 parent cfaece8 commit 5eaf91b

File tree

13 files changed: +2244 −0 lines changed

@@ -0,0 +1,65 @@
name: Automatic Model Parallelism Test on GPUs

on:
  pull_request:
    branches:
      - main
    paths:
      - 'optimum/fx/parallelization/**.py'
  push:
    branches:
      - main
    paths:
      - 'optimum/fx/parallelization/**.py'

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  run_gpu_tests:
    strategy:
      fail-fast: false
      matrix:
        config:
          - name: GPU-enabled Optimum Test Suite
            image: nvidia/cuda:12.4.1-devel-ubuntu22.04
            gpu_target: ["nvidia-multi-gpu-l4-runners", "nvidia-multi-gpu-a10-runners"]

    name: ${{ matrix.config.name }}
    runs-on:
      group: "${{matrix.gpu_target}}"

    container:
      image: ${{ matrix.config.image }}
      options: --mount type=tmpfs,destination=/tmp --shm-size 64gb --gpus all --ipc host -v /mnt/hf_cache:/mnt/cache/
    env:
      NCCL_DEBUG: INFO
      HF_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
    defaults:
      run:
        shell: bash

    steps:
      - uses: actions/setup-python@v5
        with:
          python-version: '3.10'

      - name: Checkout optimum
        uses: actions/checkout@v4
        with:
          fetch-depth: 1

      - name: Run nvidia-smi
        run: |
          nvidia-smi

      - name: Install dependencies
        run: |
          python3 -m pip install -U pip
          python3 -m pip install torch transformers
          python3 -m pip install .[tests]

      - name: Run automatic model parallelism tests
        run: |
          pytest -s -v -o log_cli=true tests/fx/parallelization
optimum/fx/parallelization/__init__.py

+16

@@ -0,0 +1,16 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .api import parallelize_backend, parallelize_model
from .core import Config, ParallelExecutionCtx
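With this `__init__`, the experimental entry points are importable from one place:

```python
# Public surface re-exported by the package __init__ in this commit
from optimum.fx.parallelization import (
    Config,
    ParallelExecutionCtx,
    parallelize_backend,
    parallelize_model,
)
```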

optimum/fx/parallelization/api.py

+126
@@ -0,0 +1,126 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
from functools import partial
from typing import List, Union

import torch
from torch.fx import GraphModule

from .core import Config, ParallelExecutionCtx
from .passes import build_parallel_pass_pipeline
from .utils import (
    MetaAwareMethodsPatcher,
    download_model_from_hf,
    initialize_parameter_meta,
    move_model_to_device,
    try_collect_weight_map,
)


def parallelize_backend(
    graph_module: GraphModule, example_inputs: List[torch.Tensor], ctx: ParallelExecutionCtx, config: Config
) -> GraphModule:
    ctx.example_inputs = example_inputs
    pass_pipeline = build_parallel_pass_pipeline()
    graph_module = pass_pipeline(graph_module=graph_module, ctx=ctx, config=config)
    ctx.compile_times += 1
    ctx.last_optimized_graph_module = graph_module
    return graph_module


def parallelize_model(
    model: Union[torch.nn.Module, str],
    parallel_ctx: ParallelExecutionCtx,
    *model_args,
    **kwargs,
):
    """
    API for automatic model parallelism through PyTorch FX.

    Args:
        model (Union[torch.nn.Module, str]):
            Model to parallelize; can be either a module or a model id on the Hugging Face Hub.
        parallel_ctx (ParallelExecutionCtx):
            Parallel execution context containing process groups the current process belongs to.
        *model_args (Any):
            Additional positional arguments for initializing the model if a model id is passed.
        revision (str, defaults to `main`):
            Model revision for weights downloading if a model id is passed.
        cache_dir (Optional[str], defaults to `None`):
            Cache directory to store downloaded weights. Defaults to None.
        local_files_only (bool, defaults to `False`):
            Whether to use local files only; avoids downloading from remote if set to `True`.
        skip_load_weights (bool, defaults to `False`):
            Whether to skip loading weights from disk into the model.
        **kwargs (Dict[str, Any]):
            Additional keyword arguments for overriding fields in the parallel config, the model config and `Model.__init__`.
    """
    revision = kwargs.pop("revision", "main")
    cache_dir = kwargs.pop("cache_dir", None)
    local_files_only = kwargs.pop("local_files_only", False)
    skip_load_weights = kwargs.pop("skip_load_weights", False)

    parallel_config = Config()
    for k, v in dict(kwargs).items():
        if k in parallel_config.__dict__:
            setattr(parallel_config, k, v)
            kwargs.pop(k)

    if isinstance(model, str):
        from transformers import AutoConfig

        is_local = os.path.isdir(model)
        if not is_local:
            hf_folder = download_model_from_hf(
                model_name_or_path=model,
                cache_dir=cache_dir,
                revision=revision,
                local_files_only=local_files_only,
                skip_download_weights=skip_load_weights,
            )
        else:
            hf_folder = model

        # should be able to load config using only local files
        model_config, kwargs = AutoConfig.from_pretrained(
            hf_folder, revision=revision, local_files_only=True, return_unused_kwargs=True, **kwargs
        )

        # try getting model class info from config
        model_arch = model_config.architectures
        model_cls = getattr(importlib.import_module("transformers"), model_arch[0])

        if not skip_load_weights:
            parallel_ctx.weight_map = try_collect_weight_map(model, cache_dir, hf_folder)

        torch_dtype, dtype_orig = kwargs.pop("torch_dtype", None), None
        if torch_dtype is not None:
            dtype_orig = model_cls._set_default_torch_dtype(torch_dtype)

        with MetaAwareMethodsPatcher():
            model = model_cls(model_config, *model_args, **kwargs)
            # TODO: remove this once training-time tracing is supported
            model.eval()

        if dtype_orig is not None:
            torch.set_default_dtype(dtype_orig)

    move_model_to_device(model, device=parallel_ctx.current_device)
    initialize_parameter_meta(model)
    backend = partial(parallelize_backend, ctx=parallel_ctx, config=parallel_config)
    model = torch.compile(model, fullgraph=True, backend=backend)
    return model
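Taken together, `parallelize_backend` plugs the pass pipeline into `torch.compile`, while `parallelize_model` handles checkpoint download, meta-device instantiation, and compilation. A minimal usage sketch, assuming a `torchrun` launch on 2 GPUs; the model id and input shapes below are placeholders, not part of this commit:

```python
# Illustrative only: save as example.py and run with
#   torchrun --nproc_per_node=2 example.py
import torch
import torch.distributed as dist

from optimum.fx.parallelization import ParallelExecutionCtx, parallelize_model

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

ctx = ParallelExecutionCtx(
    tp_group=dist.group.WORLD,              # tensor-parallel process group
    current_device=torch.device("cuda", rank),
)
model = parallelize_model("meta-llama/Llama-2-7b-hf", ctx)  # placeholder model id

input_ids = torch.randint(0, 32000, (1, 128), device=ctx.current_device)
outputs = model(input_ids)  # first call triggers dynamo capture + parallel passes
```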

optimum/fx/parallelization/core.py

+167
@@ -0,0 +1,167 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.fx import GraphModule

class HashableSlice:
    def __init__(self, start: Optional[int] = None, stop: Optional[int] = None, step: Optional[int] = None) -> None:
        self.start = start
        self.stop = stop
        self.step = step

    def __hash__(self) -> int:
        return hash(f"{self.start},{self.stop},{self.step}")

    def __eq__(self, value: object) -> bool:
        return (
            isinstance(value, HashableSlice)
            and self.start == value.start
            and self.stop == value.stop
            and self.step == value.step
        )

    def to_slice(self) -> slice:
        return slice(self.start, self.stop, self.step)
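`HashableSlice` exists because these slices are used as mapping keys, and plain `slice` objects only became hashable in CPython 3.12. A small illustration; the weight name and shapes here are made up:

```python
# Illustration only: HashableSlice keys a dict, then converts back to a
# real slice for tensor indexing.
import torch

from optimum.fx.parallelization.core import HashableSlice

idx = HashableSlice(0, 2048)                # rows 0:2048
table = {idx: "model.embed_tokens.weight"}  # works as a dict key

weight = torch.empty(4096, 1024)
shard = weight[idx.to_slice()]              # same as weight[0:2048]
assert shard.shape == (2048, 1024)
```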

@dataclass
class ParameterSlice:
    """
    A slice of a parameter that corresponds to a tensor in the weight dict. Only slicing
    along a specific axis (the potential parallel axis) is supported right now.

    Attributes:
        - source (`Optional[str]`, defaults to `None`):
            Original parameter name, which can be found in the weight dict.

        - shape (`Optional[Tuple]`, defaults to `None`):
            Shape of the parameter tensor corresponding to `source`.

        - index (`slice`, defaults to `slice(None, None, None)`):
            Index used to slice the tensor on the parallel axis. Assumes tensors in the
            weight dict have the same layout as their counterparts in memory.
    """

    source: Optional[str] = None
    shape: Optional[Tuple] = None
    index: slice = slice(None, None, None)

@dataclass
class ParameterMeta:
    """
    Parameter meta information.

    Attributes:
        - is_tied (`bool`, defaults to `False`):
            Whether the parameter is shared across multiple modules.

        - is_parallel (`bool`, defaults to `False`):
            Whether the parameter needs to be parallelized.

        - is_modified_meta (`bool`, defaults to `False`):
            Whether the meta has already been modified since initialization.

        - need_initialize (`bool`, defaults to `False`):
            Whether the weights need to be manually initialized when they are not provided
            in the weight map.

        - init_fn (`Optional[Callable]`, defaults to `None`):
            Initialization function; overrides `weight_init_fn` in `Config` when not None.

        - dim (`int`, defaults to `0`):
            Axis on which `mapping` is based; also the parallel axis if `is_parallel`.

        - mapping (`Dict[HashableSlice, ParameterSlice]`):
            Mapping between the current parameter and the weight tensors stored in the weight map.
    """

    is_tied: bool = False
    is_parallel: bool = False
    is_modified_meta: bool = False
    need_initialize: bool = False
    init_fn: Optional[Callable] = None
    dim: int = 0
    mapping: Dict[HashableSlice, ParameterSlice] = field(default_factory=dict)
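As a concrete picture of how `ParameterSlice` and `ParameterMeta` compose, here is a hypothetical record for rank 0 of a 2-way column-parallel `nn.Linear` weight; how the passes actually populate `mapping` lives elsewhere in this PR, so treat the values as illustrative:

```python
# Hypothetical metadata for rank 0's shard of a 2-way column-parallel
# nn.Linear weight of full shape (4096, 1024); names/values are made up.
meta = ParameterMeta(
    is_parallel=True,
    dim=0,  # output features are the parallel axis
    mapping={
        # local rows 0:2048 map to rows 0:2048 of the checkpoint tensor
        HashableSlice(0, 2048): ParameterSlice(
            source="model.layers.0.self_attn.q_proj.weight",
            shape=(4096, 1024),
            index=slice(0, 2048),
        )
    },
)
```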

@dataclass
class ParallelExecutionCtx:
    """
    Parallel execution context which contains runtime information.

    Attributes:
        - tp_group (`dist.ProcessGroup`):
            Tensor parallel process group the current process belongs to.

        - current_device (`torch.device`):
            Device corresponding to the current process.

        - example_inputs (`List[Any]`):
            A list of tensors which are used as example inputs for graphs captured by dynamo.

        - parallel_layer_cache (`Dict[str, nn.Module]`):
            Cache which maps layers (`nn.Linear`, `nn.Embedding`) to their parallel counterparts.
            Note that we build the cache during the first compilation, and for later recompilations
            we directly replace the modules with their parallel counterparts from the cache, because
            we have to make sure we don't initiate new parameters and replace original ones when
            recompilation happens during training.

        - weight_map (`Dict[str, str]`):
            Mapping between parameter names and their locations on disk, useful when loading weights
            from disk.

        - last_optimized_graph_module (`Optional[GraphModule]`, defaults to `None`):
            Optimized graph module corresponding to the latest compilation.

        - compile_times (`int`, defaults to `0`):
            Number of compilations that have happened over the whole process.
    """

    tp_group: dist.ProcessGroup
    current_device: torch.device
    example_inputs: List[Any] = field(default_factory=list)
    parallel_layer_cache: Dict[str, nn.Module] = field(default_factory=dict)
    weight_map: Dict[str, str] = field(default_factory=dict)
    last_optimized_graph_module: Optional[GraphModule] = None
    compile_times: int = 0

@dataclass
class Config:
    """
    Static config containing settings that do not change at runtime.

    Attributes:
        - lint_and_recompile (`bool`, defaults to `True`):
            Whether to run graph linting and module recompilation after every pass.

        - clean_markers_after_all_passes (`bool`, defaults to `True`):
            Whether to clean the markers of analytical passes after all passes have run.

        - weight_init_fn (`Callable`, defaults to `partial(nn.init.normal_, std=0.02)`):
            Initialization function for weights in `nn.Linear` and `nn.Embedding` layers,
            used when no weight-loading path is provided.
    """

    lint_and_recompile: bool = True
    clean_markers_after_all_passes: bool = True
    weight_init_fn: Callable = partial(nn.init.normal_, std=0.02)
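Since `parallelize_model` (api.py above) pops any keyword argument whose name matches a `Config` field, these defaults can be overridden at the call site. A sketch, reusing the `ctx` from the earlier example and a placeholder model id:

```python
# kwargs matching Config fields are routed into the parallel config;
# everything left over goes to the model config / constructor.
import torch.nn as nn

from optimum.fx.parallelization import parallelize_model

model = parallelize_model(
    "meta-llama/Llama-2-7b-hf",            # placeholder model id
    ctx,
    lint_and_recompile=False,              # Config field
    weight_init_fn=nn.init.xavier_normal_, # Config field
)
```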
@@ -0,0 +1,21 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .dist_ops import (
    differentiable_all_gather,
    differentiable_all_reduce_sum,
    differentiable_identity,
    differentiable_scatter,
    scatter,
)
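The `differentiable_*` names follow the usual tensor-parallel convention in which each collective's backward is its dual (identity pairs with all-reduce, scatter with all-gather). The bodies of these ops are not part of this excerpt, so the following is only a sketch of that general pattern, not the PR's actual implementation:

```python
# Sketch of the standard duality behind differentiable TP collectives;
# NOT the code added by this commit.
import torch
import torch.distributed as dist


class _AllReduceSum(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, group):
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        # backward of sum-reduce is identity: gradients pass through
        return grad_output, None


class _Identity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, group):
        ctx.group = group
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        # backward of identity is an all-reduce over the TP group
        dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=ctx.group)
        return grad_output, None
```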
