# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.fx import GraphModule


class HashableSlice:
    def __init__(self, start: Optional[int] = None, stop: Optional[int] = None, step: Optional[int] = None) -> None:
        self.start = start
        self.stop = stop
        self.step = step

    def __hash__(self) -> int:
        return hash((self.start, self.stop, self.step))

    def __eq__(self, value: object) -> bool:
        return (
            isinstance(value, HashableSlice)
            and self.start == value.start
            and self.stop == value.stop
            and self.step == value.step
        )

    def to_slice(self) -> slice:
        return slice(self.start, self.stop, self.step)

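# Illustration only (not from the original file): `HashableSlice` exists because
# plain `slice` objects are unhashable before Python 3.12 and therefore cannot
# key the `mapping` dict defined below. A minimal sketch of its use:
#
#   key = HashableSlice(0, 1024)
#   shards = {key: "rank-0 shard"}
#   assert shards[HashableSlice(0, 1024)] == "rank-0 shard"
#   first_rows = weight[key.to_slice()]  # `weight` is a hypothetical tensor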

@dataclass
class ParameterSlice:
    """
    A slice of a parameter which corresponds to a tensor in the weight dict. Only slicing
    along a specific axis (the potential parallel axis) is supported right now.

    Attributes:
        - source (`Optional[str]`, defaults to `None`):
            Original parameter name which can be found in the weight dict.

        - shape (`Optional[Tuple]`, defaults to `None`):
            Shape of the parameter tensor corresponding to `source`.

        - index (`slice`, defaults to `slice(None, None, None)`):
            Index used to slice the tensor along the parallel axis. Assumes tensors in the
            weight dict have the same layout as their counterparts in memory.
    """

    source: Optional[str] = None
    shape: Optional[Tuple] = None
    index: slice = slice(None, None, None)

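# Illustration only (hypothetical checkpoint key and shape): the lower half of
# an embedding weight stored on disk under "embed_tokens.weight" with full
# shape (32000, 4096), sliced along the parallel axis (dim 0):
#
#   piece = ParameterSlice(
#       source="embed_tokens.weight",
#       shape=(32000, 4096),
#       index=slice(0, 16000),
#   )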

@dataclass
class ParameterMeta:
    """
    Parameter meta information.

    Attributes:
        - is_tied (`bool`, defaults to `False`):
            Whether the parameter is shared across multiple modules.

        - is_parallel (`bool`, defaults to `False`):
            Whether the parameter needs to be parallelized.

        - is_modified_meta (`bool`, defaults to `False`):
            Whether the meta has already been modified since initialization.

        - need_initialize (`bool`, defaults to `False`):
            Whether the weights need to be manually initialized when they are not provided
            in the weight map.

        - init_fn (`Optional[Callable]`, defaults to `None`):
            Initialization function which overrides `weight_init_fn` in `Config` if not None.

        - dim (`int`, defaults to `0`):
            Axis on which `mapping` is based, which is also the parallel axis if `is_parallel`.

        - mapping (`Dict[HashableSlice, ParameterSlice]`):
            Mapping between the current parameter and the weight tensors stored in the weight map.
    """

    is_tied: bool = False
    is_parallel: bool = False
    is_modified_meta: bool = False
    need_initialize: bool = False
    init_fn: Optional[Callable] = None
    dim: int = 0
    mapping: Dict[HashableSlice, ParameterSlice] = field(default_factory=dict)

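# Illustration only (hypothetical checkpoint key and shapes): meta for rank 0's
# shard of a column-parallel linear weight split over two ranks along dim 0.
# Local rows [0, 4096) are taken from rows [0, 4096) of the full (8192, 4096)
# tensor stored in the weight map:
#
#   meta = ParameterMeta(
#       is_parallel=True,
#       dim=0,
#       mapping={
#           HashableSlice(0, 4096): ParameterSlice(
#               source="model.layers.0.self_attn.q_proj.weight",
#               shape=(8192, 4096),
#               index=slice(0, 4096),
#           )
#       },
#   )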

@dataclass
class ParallelExecutionCtx:
    """
    Parallel execution context which contains runtime information.

    Attributes:
        - tp_group (`dist.ProcessGroup`):
            Tensor parallel process group the current process belongs to.

        - current_device (`torch.device`):
            Device corresponding to the current process.

        - example_inputs (`List[Any]`):
            A list of tensors which are used as example inputs for graphs captured by dynamo.

        - parallel_layer_cache (`Dict[str, nn.Module]`):
            Cache which maps layers (e.g. `nn.Linear`, `nn.Embedding`) to their parallel counterparts.
            Note that the cache is built during the first compilation, and on later recompilations
            modules are replaced directly with their parallel counterparts from the cache, because
            we must not initialize new parameters and replace the original ones when recompilation
            happens during training.

        - weight_map (`Dict[str, str]`):
            Mapping between parameter names and their locations on disk, useful when loading weights
            from disk.

        - last_optimized_graph_module (`Optional[GraphModule]`, defaults to `None`):
            Optimized graph module corresponding to the latest compilation.

        - compile_times (`int`, defaults to `0`):
            Number of compilations that have happened during the whole process.
    """

    tp_group: dist.ProcessGroup
    current_device: torch.device
    example_inputs: List[Any] = field(default_factory=list)
    parallel_layer_cache: Dict[str, nn.Module] = field(default_factory=dict)
    weight_map: Dict[str, str] = field(default_factory=dict)
    last_optimized_graph_module: Optional[GraphModule] = None
    compile_times: int = 0

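# Illustration only (assumes `torch.distributed` has already been initialized,
# e.g. via `dist.init_process_group("nccl")` under torchrun, one GPU per rank):
#
#   if dist.is_initialized():
#       ctx = ParallelExecutionCtx(
#           tp_group=dist.new_group(),  # all ranks form one tensor-parallel group
#           current_device=torch.device("cuda", dist.get_rank()),
#       )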

@dataclass
class Config:
    """
    Static config which contains instructions that do not change at runtime.

    Attributes:
        - lint_and_recompile (`bool`, defaults to `True`):
            Whether to run graph linting and module recompilation after every pass.

        - clean_markers_after_all_passes (`bool`, defaults to `True`):
            Whether to clean the markers of analytical passes after all passes have run.

        - weight_init_fn (`Callable`, defaults to `partial(nn.init.normal_, std=0.02)`):
            Initialization function for weights in `nn.Linear` and `nn.Embedding` layers,
            used when no weight-loading path is provided.
    """

    lint_and_recompile: bool = True
    clean_markers_after_all_passes: bool = True
    weight_init_fn: Callable = partial(nn.init.normal_, std=0.02)
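

if __name__ == "__main__":
    # Minimal runnable sketch (an illustration, not part of the library API):
    # override the default initializer with a smaller std and apply it to a
    # freshly created layer's weight, as would happen for a parameter that has
    # no entry in the weight map.
    config = Config(weight_init_fn=partial(nn.init.normal_, std=0.01))
    layer = nn.Linear(16, 16)
    config.weight_init_fn(layer.weight)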