Commit 4efd6ab

enable layer-by-layer (#25)
Signed-off-by: Mengni Wang <mengni.wang@intel.com>
Signed-off-by: yi <yi>
Co-authored-by: Yi Liu <yi4.liu@intel.com>
Co-authored-by: yi <yi>
1 parent 896dca1 commit 4efd6ab

6 files changed: +166 -53 lines changed

scripts/QuantizeDeepSeek.md (+5 -3)

@@ -155,11 +155,9 @@ In this section, we load the BF16 model on DRAM and quantize it to FP8 model usi
 
 ### Prerequisites
 
-- Hardware: 1x8G3 or 1x8G2(WIP), 2T DRAM
+- Hardware: 1x8G3 or 1x8G2
 - Docker: 1.20.0-521
 
-> [!NOTE] The DRAM requirement can be decreased to less than 1T in a few days.
-
 ### Running the Example
 
 - BF16 KVCache
@@ -180,6 +178,10 @@ huggingface-cli download Yi30/inc-tp8-ep8-full-kvcache-from-tp16-ep16 --local-di
 QUANT_CONFIG=inc_quant_with_fp8kv_one_node_config.json python inc_example_one_node.py --fp8_kvcache
 ```
 
+> [!NOTE]
+> If your DRAM is less than 2T, please use `LOW_CPU_MEM=1` to open layer-by-layer conversion.
+
+
 ## Accuracy Evaluation (WIP)
 
 ## Calibration with Custom Dataset (WIP)
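
The README note above only names the switch. As a rough illustration (not part of the commit), a launcher along these lines would combine the documented `QUANT_CONFIG` setting with the new `LOW_CPU_MEM` flag; the script and config names come from the README, while the subprocess wrapper is just one way to set the environment:

```python
# Hypothetical launcher: run the documented one-node FP8 example with
# layer-by-layer conversion enabled (for hosts with less than 2T of DRAM).
import os
import subprocess

env = dict(
    os.environ,
    QUANT_CONFIG="inc_quant_with_fp8kv_one_node_config.json",  # from the README
    LOW_CPU_MEM="1",  # switch introduced by this commit
)
subprocess.run(
    ["python", "inc_example_one_node.py", "--fp8_kvcache"],
    env=env,
    check=True,
)
```

Note that both sourcing scripts below now `unset LOW_CPU_MEM` and echo it, so a stale value from a previous shell will not silently re-enable the split-weight path.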

scripts/head_node_source.sh (+5)

@@ -44,6 +44,7 @@ export VLLM_GRAPH_PROMPT_RATIO=0
 
 # INC
 unset QUANT_CONFIG
+unset LOW_CPU_MEM
 
 # Fot prepare
 max_num_batched_tokens=2048
@@ -60,3 +61,7 @@ unset VLLM_DECODE_BLOCK_BUCKET_MIN VLLM_DECODE_BLOCK_BUCKET_STEP VLLM_DECODE_BLO
 set_bucketing
 echo " environments are reseted "
 env | grep VLLM
+
+# check quant_config and low_cpu_mem
+echo "QUANT_CONFIG: $QUANT_CONFIG"
+echo "LOW_CPU_MEM: $LOW_CPU_MEM"

scripts/n2_ep8_tp8.py (+1)

@@ -273,6 +273,7 @@ def sample_gsm8k_requests(
         distributed_executor_backend='mp',
         trust_remote_code=True,
         quantization=quantization,
+        weights_load_device="cpu",
         max_model_len=16384,
         dtype="bfloat16",
     )

scripts/worker_node_source.sh (+5 -1)

@@ -36,7 +36,7 @@ export VLLM_GRAPH_PROMPT_RATIO=0
 
 # INC
 unset QUANT_CONFIG
-
+unset LOW_CPU_MEM
 # params
 # max_num_batched_tokens=2048
 # max_num_seqs=1024
@@ -58,3 +58,7 @@ unset VLLM_DECODE_BLOCK_BUCKET_MIN VLLM_DECODE_BLOCK_BUCKET_STEP VLLM_DECODE_BLO
 set_bucketing
 echo " environments are reseted "
 env | grep VLLM
+
+# check quant_config and low_cpu_mem
+echo "QUANT_CONFIG: $QUANT_CONFIG"
+echo "LOW_CPU_MEM: $LOW_CPU_MEM"

vllm/model_executor/layers/fused_moe/layer.py (+147 -48)
@@ -33,6 +33,7 @@
 
 import os
 VLLM_LOAD_FOR_INC = os.environ.get("VLLM_LOAD_FOR_INC", "0") == "1"
+LOW_CPU_MEM = os.environ.get("LOW_CPU_MEM", "0") == "1"
 
 # ==-------------------------------------------------------------------------==
 # VLLM-HPU-EXT PATCH Start
@@ -42,7 +43,6 @@
 import habana_frameworks.torch.core as htcore
 import habana_frameworks.torch as htorch
 
-
 class MoeMatmul(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -57,9 +57,17 @@ def forward(self, state, expert_id, w):
 class VllmMixtureOfExpertsOp(torch.nn.Module):
     def __init__(self, num_total_experts):
         super().__init__()
-        self.w13_list = torch.nn.ModuleList(
-            [MoeMatmul() for _ in range(num_total_experts)]
-        )
+        if not LOW_CPU_MEM:
+            self.w13_list = torch.nn.ModuleList(
+                [MoeMatmul() for _ in range(num_total_experts)]
+            )
+        else:
+            self.w1_list = torch.nn.ModuleList(
+                [MoeMatmul() for _ in range(num_total_experts)]
+            )
+            self.w3_list = torch.nn.ModuleList(
+                [MoeMatmul() for _ in range(num_total_experts)]
+            )
         self.w2_list = torch.nn.ModuleList(
             [MoeMatmul() for _ in range(num_total_experts)]
         )
@@ -82,20 +90,36 @@ def forward(
         assert self.experts_min is not None, "`experts_min` is not set"
         assert self.experts_max is not None, "`experts_max` is not set"
         experts_min, experts_max = self.experts_min, self.experts_max
-        w1_list = [self.w13_list[i].weight.squeeze() for i in experts_range]
-        w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range]
-        return torch.ops.hpu.mixture_of_experts(
-            hidden_states=hidden_states,
-            expert_routing_table=expert_routing_table,
-            router_weights=router_weights,
-            w12=w1_list,
-            w3=w2_list,
-            permuted_weights=permuted_weights,
-            activation=activation,
-            experts_min=experts_min,
-            experts_max=experts_max,
-        )
-
+        if not LOW_CPU_MEM:
+            w1_list = [self.w13_list[i].weight.squeeze() for i in experts_range]
+            w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range]
+            return torch.ops.hpu.mixture_of_experts.fused_weights(
+                hidden_states=hidden_states,
+                expert_routing_table=expert_routing_table,
+                router_weights=router_weights,
+                w12=w1_list,
+                w3=w2_list,
+                permuted_weights=permuted_weights,
+                activation=activation,
+                experts_min=experts_min,
+                experts_max=experts_max,
+            )
+        else:
+            w1_list = [self.w1_list[i].weight.squeeze() for i in experts_range]
+            w3_list = [self.w3_list[i].weight.squeeze() for i in experts_range]
+            w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range]
+            return torch.ops.hpu.mixture_of_experts.default(
+                hidden_states=hidden_states,
+                expert_routing_table=expert_routing_table,
+                router_weights=router_weights,
+                w1=w1_list,
+                w2=w2_list,
+                w3=w3_list,
+                permuted_weights=permuted_weights,
+                activation=activation,
+                experts_min=experts_min,
+                experts_max=experts_max,
+            )
 
 class _DynamicFusedMOE(torch.nn.Module):
     def __init__(self, num_total_experts):
@@ -165,14 +189,33 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
         # Fused gate_up_proj (column parallel)
-        w13_weight = torch.nn.Parameter(torch.empty(
-            num_experts,
-            2 * intermediate_size_per_partition,
-            hidden_size,
-            dtype=params_dtype),
-            requires_grad=False)
-        layer.register_parameter("w13_weight", w13_weight)
-        set_weight_attrs(w13_weight, extra_weight_attrs)
+        if not LOW_CPU_MEM:
+            w13_weight = torch.nn.Parameter(torch.empty(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=params_dtype),
+                requires_grad=False)
+            layer.register_parameter("w13_weight", w13_weight)
+            set_weight_attrs(w13_weight, extra_weight_attrs)
+        else:
+            w1_weight = torch.nn.Parameter(torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                hidden_size,
+                dtype=params_dtype),
+                requires_grad=False)
+            layer.register_parameter("w1_weight", w1_weight)
+            set_weight_attrs(w1_weight, extra_weight_attrs)
+
+            w3_weight = torch.nn.Parameter(torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                hidden_size,
+                dtype=params_dtype),
+                requires_grad=False)
+            layer.register_parameter("w3_weight", w3_weight)
+            set_weight_attrs(w3_weight, extra_weight_attrs)
 
         # down_proj (row parallel)
         w2_weight = torch.nn.Parameter(torch.empty(
@@ -327,8 +370,8 @@ def forward_hpu_original(
             max_expert = (i + 1) * n_expert_slice
             # w13_list_slice = [w13_list[i].weight.squeeze() for i in range(min_expert, max_expert)]
            # w2_list_slice = [w2_list[i].weight.squeeze() for i in range(min_expert, max_expert)]
-            w13_list_slice = [layer.w13_weight[j].squeeze().clone() for j in range(min_expert, max_expert)]
-            w2_list_slice = [layer.w2_weight[j].squeeze().clone() for j in range(min_expert, max_expert)]
+            w13_list_slice = [layer.w13_weight[j].squeeze() for j in range(min_expert, max_expert)]
+            w2_list_slice = [layer.w2_weight[j].squeeze() for j in range(min_expert, max_expert)]
             final_hidden_states += torch.ops.hpu.mixture_of_experts(hidden_states=x,
                 expert_routing_table=topk_ids.to(torch.int64),
                 router_weights=topk_weights.to(x.dtype),
@@ -612,21 +655,54 @@ def __init__(
             max_expert = (i + 1) * n_expert_slice
             # Note: clone weight will cause OoM.
             # rank_debug(f"i:{i}, num_experts:{num_experts} loading experts from {min_expert} to {max_expert}, layer.w13_weight.shape : {layer.w13_weight.shape}")
-            w13_list_slice = [
-                layer.w13_weight[j]
-                for j in range(min_expert, max_expert)
-            ]
+            if not LOW_CPU_MEM:
+                w13_list_slice = [
+                    layer.w13_weight[j]
+                    for j in range(min_expert, max_expert)
+                ]
+            else:
+                w1_list_slice = [
+                    layer.w1_weight[j]
+                    for j in range(min_expert, max_expert)
+                ]
+                w3_list_slice = [
+                    layer.w3_weight[j]
+                    for j in range(min_expert, max_expert)
+                ]
             w2_list_slice = [
                 layer.w2_weight[j]
                 for j in range(min_expert, max_expert)
             ]
-            for index in range(len(w13_list_slice)):
-                _temp_expert_group.MoeOp.w13_list[index].set_weight(
-                    w13_list_slice[index]
-                )
-                _temp_expert_group.MoeOp.w2_list[index].set_weight(
-                    w2_list_slice[index]
-                )
+            for index in range(len(w2_list_slice)):
+                if not LOW_CPU_MEM:
+                    _temp_expert_group.MoeOp.w13_list[index].set_weight(
+                        w13_list_slice[index]
+                    )
+                    _temp_expert_group.MoeOp.w2_list[index].set_weight(
+                        w2_list_slice[index]
+                    )
+                else:
+                    _temp_expert_group.MoeOp.w1_list[index].set_weight(
+                        torch.nn.Parameter(torch.empty(
+                            w1_list_slice[index].shape,
+                            dtype=w1_list_slice[index].dtype,
+                        ),
+                        requires_grad=False)
+                    )
+                    _temp_expert_group.MoeOp.w3_list[index].set_weight(
+                        torch.nn.Parameter(torch.empty(
+                            w3_list_slice[index].shape,
+                            dtype=w3_list_slice[index].dtype,
+                        ),
+                        requires_grad=False)
+                    )
+                    _temp_expert_group.MoeOp.w2_list[index].set_weight(
+                        torch.nn.Parameter(torch.empty(
+                            w2_list_slice[index].shape,
+                            dtype=w2_list_slice[index].dtype,
+                        ),
+                        requires_grad=False)
+                    )
             # FIXME: (Yi) pass `experts_min` and `experts_max` to MoeOp.
             setattr(_temp_expert_group.MoeOp, "experts_min", min_expert + ep_shift)
             setattr(_temp_expert_group.MoeOp, "experts_max", max_expert - 1 + ep_shift)
@@ -891,13 +967,31 @@ def weight_loader(self, param: torch.nn.Parameter,
 
         # Case model weights
         if "weight" in weight_name:
-            self._load_model_weight_or_group_weight_scale(
-                shard_id=shard_id,
-                shard_dim=shard_dim,
-                loaded_weight=loaded_weight,
-                expert_data=expert_data,
-                tp_rank=tp_rank,
-                expert_id=expert_id)
+            if not LOW_CPU_MEM:
+                self._load_model_weight_or_group_weight_scale(
+                    shard_id=shard_id,
+                    shard_dim=shard_dim,
+                    loaded_weight=loaded_weight,
+                    expert_data=expert_data,
+                    tp_rank=tp_rank,
+                    expert_id=expert_id)
+            else:
+                expert_group = expert_id // (self.num_experts//self.num_expert_group)
+                lst_id = expert_id % (self.num_experts//self.num_expert_group)
+                expert = getattr(self, f"_temp_expert_group_{expert_group}")
+                if shard_id in ("w1", "w3"):
+                    if shard_id == "w1":
+                        expert_data = expert.MoeOp.w1_list[lst_id].weight
+                        expert_data.data = loaded_weight
+                    else:
+                        assert shard_id == "w3"
+                        expert_data = expert.MoeOp.w3_list[lst_id].weight
+                        expert_data.data = loaded_weight
+
+                elif shard_id == "w2":
+                    # load w2
+                    expert_data = expert.MoeOp.w2_list[lst_id].weight
+                    expert_data.data = loaded_weight
             return
 
     @staticmethod
@@ -973,14 +1067,19 @@ def make_expert_params_mapping(
             ckpt_up_proj_name: str,
             num_experts: int) -> List[Tuple[str, str, int, str]]:
 
+        mapping = {
+            ckpt_gate_proj_name: "experts.w1_" if LOW_CPU_MEM else "experts.w13_",
+            ckpt_down_proj_name: "experts.w2_",
+            ckpt_up_proj_name: "experts.w3_" if LOW_CPU_MEM else "experts.w13_",
+        }
         return [
             # (param_name, weight_name, expert_id, shard_id)
-            ("experts.w13_" if weight_name
-             in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
+            (mapping[weight_name],
              f"experts.{expert_id}.{weight_name}.", expert_id, shard_id)
             for expert_id in range(num_experts) for shard_id, weight_name in [
                 ("w1", ckpt_gate_proj_name),
                 ("w2", ckpt_down_proj_name),
                 ("w3", ckpt_up_proj_name),
             ]
         ]
+
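
Taken together, the layer.py hunks switch the MoE op between two weight layouts. A minimal, self-contained sketch of the idea (class and attribute names mirror the diff; the HPU kernel call is left out so the snippet runs on plain PyTorch, and it is only an illustration of the commit's approach):

```python
# Sketch of the LOW_CPU_MEM layout choice: with the flag off, each expert's
# gate and up projections live in one fused w13 tensor of shape
# (2 * intermediate, hidden); with the flag on, they are kept as separate
# w1/w3 lists so checkpoint shards can be assigned slot by slot without
# materialising the doubled fused buffer on DRAM first.
import os
import torch

LOW_CPU_MEM = os.environ.get("LOW_CPU_MEM", "0") == "1"


class MoeMatmul(torch.nn.Module):
    """Thin holder for one expert's weight, mirroring the patched class."""

    def set_weight(self, w: torch.nn.Parameter) -> None:
        self.weight = w


class MoeOpSketch(torch.nn.Module):
    def __init__(self, num_total_experts: int) -> None:
        super().__init__()
        if not LOW_CPU_MEM:
            # Default path: gate and up projections stay fused as w13.
            self.w13_list = torch.nn.ModuleList(
                [MoeMatmul() for _ in range(num_total_experts)])
        else:
            # Layer-by-layer path: separate per-expert gate (w1) and up (w3)
            # lists, filled directly by the new weight_loader branch.
            self.w1_list = torch.nn.ModuleList(
                [MoeMatmul() for _ in range(num_total_experts)])
            self.w3_list = torch.nn.ModuleList(
                [MoeMatmul() for _ in range(num_total_experts)])
        self.w2_list = torch.nn.ModuleList(
            [MoeMatmul() for _ in range(num_total_experts)])


# Quick check of which attributes exist under the current setting.
op = MoeOpSketch(num_total_experts=4)
print(hasattr(op, "w13_list"), hasattr(op, "w1_list"), hasattr(op, "w3_list"))
```

At dispatch time the diff then picks the matching kernel overload: `torch.ops.hpu.mixture_of_experts.fused_weights` with the fused `w12`/`w3` arguments, or `torch.ops.hpu.mixture_of_experts.default` with separate `w1`/`w2`/`w3` lists when `LOW_CPU_MEM=1`, while `make_expert_params_mapping` routes gate/up checkpoint tensors to `experts.w1_`/`experts.w3_` instead of `experts.w13_`.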

vllm/worker/hpu_model_runner.py (+3 -1)

@@ -2361,7 +2361,8 @@ def try_revert_dummy_output_tokens():
                 LoraMask.setLoraMask(
                     lora_logits_mask.index_select(
                         0, sampling_metadata.selected_token_indices))
-
+                # NOTE: In case the driver rank's compute_logits hangs on two nodes.
+                torch.hpu.synchronize()
                 # Compute the logits.
                 with self.profiler.record_event(
                         'internal',
@@ -2373,6 +2374,7 @@ def try_revert_dummy_output_tokens():
                     sampling_metadata.selected_token_indices = None
                     logits = self.model.compute_logits(hidden_states,
                                                        sampling_metadata)
+
                 htorch.core.mark_step()
                 # Only perform sampling in the driver worker.
                 if not self.is_driver_worker:
