enable layer-by-layer #25

Merged · 5 commits · Mar 5, 2025

8 changes: 5 additions & 3 deletions scripts/QuantizeDeepSeek.md
@@ -155,11 +155,9 @@ In this section, we load the BF16 model on DRAM and quantize it to FP8 model usi

### Prerequisites

- Hardware: 1x8G3 or 1x8G2(WIP), 2T DRAM
- Hardware: 1x8G3 or 1x8G2
- Docker: 1.20.0-521

> [!NOTE] The DRAM requirement can be decreased to less than 1T in a few days.

### Running the Example

- BF16 KVCache
@@ -180,6 +178,10 @@ huggingface-cli download Yi30/inc-tp8-ep8-full-kvcache-from-tp16-ep16 --local-di
QUANT_CONFIG=inc_quant_with_fp8kv_one_node_config.json python inc_example_one_node.py --fp8_kvcache
```

> [!NOTE]
> If your DRAM is less than 2T, set `LOW_CPU_MEM=1` to enable layer-by-layer conversion.
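For reference, a minimal Python sketch of driving the same run from a script rather than the shell. The model path and the `quantization="inc"` value are placeholder assumptions, not taken from this PR; only `LOW_CPU_MEM`, `QUANT_CONFIG`, `weights_load_device="cpu"`, `dtype="bfloat16"`, and `trust_remote_code=True` appear in the changed files. Because `fused_moe/layer.py` reads `LOW_CPU_MEM` at import time, the variable has to be set before vLLM is imported:

```python
import os

# Must be set before importing vLLM: fused_moe/layer.py evaluates
# LOW_CPU_MEM = os.environ.get("LOW_CPU_MEM", "0") == "1" at import time.
os.environ["LOW_CPU_MEM"] = "1"
os.environ["QUANT_CONFIG"] = "inc_quant_with_fp8kv_one_node_config.json"

from vllm import LLM

llm = LLM(
    model="/path/to/DeepSeek-R1-BF16",  # placeholder path, not from this PR
    quantization="inc",                 # assumed INC backend name
    weights_load_device="cpu",          # load BF16 weights to host DRAM first
    dtype="bfloat16",
    trust_remote_code=True,
)
```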


## Accuracy Evaluation (WIP)

## Calibration with Custom Dataset (WIP)
5 changes: 5 additions & 0 deletions scripts/head_node_source.sh
@@ -44,6 +44,7 @@ export VLLM_GRAPH_PROMPT_RATIO=0

# INC
unset QUANT_CONFIG
unset LOW_CPU_MEM

# For prepare
max_num_batched_tokens=2048
@@ -60,3 +61,7 @@ unset VLLM_DECODE_BLOCK_BUCKET_MIN VLLM_DECODE_BLOCK_BUCKET_STEP VLLM_DECODE_BLO
set_bucketing
echo " environments are reseted "
env | grep VLLM

# check quant_config and low_cpu_mem
echo "QUANT_CONFIG: $QUANT_CONFIG"
echo "LOW_CPU_MEM: $LOW_CPU_MEM"
1 change: 1 addition & 0 deletions scripts/n2_ep8_tp8.py
@@ -273,6 +273,7 @@ def sample_gsm8k_requests(
distributed_executor_backend='mp',
trust_remote_code=True,
quantization=quantization,
weights_load_device="cpu",
max_model_len=16384,
dtype="bfloat16",
)
6 changes: 5 additions & 1 deletion scripts/worker_node_source.sh
@@ -36,7 +36,7 @@ export VLLM_GRAPH_PROMPT_RATIO=0

# INC
unset QUANT_CONFIG

unset LOW_CPU_MEM
# params
# max_num_batched_tokens=2048
# max_num_seqs=1024
@@ -58,3 +58,7 @@ unset VLLM_DECODE_BLOCK_BUCKET_MIN VLLM_DECODE_BLOCK_BUCKET_STEP VLLM_DECODE_BLO
set_bucketing
echo " environments are reseted "
env | grep VLLM

# check quant_config and low_cpu_mem
echo "QUANT_CONFIG: $QUANT_CONFIG"
echo "LOW_CPU_MEM: $LOW_CPU_MEM"
195 changes: 147 additions & 48 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -33,6 +33,7 @@

import os
VLLM_LOAD_FOR_INC = os.environ.get("VLLM_LOAD_FOR_INC", "0") == "1"
LOW_CPU_MEM = os.environ.get("LOW_CPU_MEM", "0") == "1"

# ==-------------------------------------------------------------------------==
# VLLM-HPU-EXT PATCH Start
@@ -42,7 +43,6 @@
import habana_frameworks.torch.core as htcore
import habana_frameworks.torch as htorch


class MoeMatmul(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -57,9 +57,17 @@ def forward(self, state, expert_id, w):
class VllmMixtureOfExpertsOp(torch.nn.Module):
def __init__(self, num_total_experts):
super().__init__()
self.w13_list = torch.nn.ModuleList(
[MoeMatmul() for _ in range(num_total_experts)]
)
if not LOW_CPU_MEM:
self.w13_list = torch.nn.ModuleList(
[MoeMatmul() for _ in range(num_total_experts)]
)
else:
self.w1_list = torch.nn.ModuleList(
[MoeMatmul() for _ in range(num_total_experts)]
)
self.w3_list = torch.nn.ModuleList(
[MoeMatmul() for _ in range(num_total_experts)]
)
self.w2_list = torch.nn.ModuleList(
[MoeMatmul() for _ in range(num_total_experts)]
)
@@ -82,20 +90,36 @@ def forward(
assert self.experts_min is not None, "`experts_min` is not set"
assert self.experts_max is not None, "`experts_max` is not set"
experts_min, experts_max = self.experts_min, self.experts_max
w1_list = [self.w13_list[i].weight.squeeze() for i in experts_range]
w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range]
return torch.ops.hpu.mixture_of_experts(
hidden_states=hidden_states,
expert_routing_table=expert_routing_table,
router_weights=router_weights,
w12=w1_list,
w3=w2_list,
permuted_weights=permuted_weights,
activation=activation,
experts_min=experts_min,
experts_max=experts_max,
)

if not LOW_CPU_MEM:
w1_list = [self.w13_list[i].weight.squeeze() for i in experts_range]
w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range]
return torch.ops.hpu.mixture_of_experts.fused_weights(
hidden_states=hidden_states,
expert_routing_table=expert_routing_table,
router_weights=router_weights,
w12=w1_list,
w3=w2_list,
permuted_weights=permuted_weights,
activation=activation,
experts_min=experts_min,
experts_max=experts_max,
)
else:
w1_list = [self.w1_list[i].weight.squeeze() for i in experts_range]
w3_list = [self.w3_list[i].weight.squeeze() for i in experts_range]
w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range]
return torch.ops.hpu.mixture_of_experts.default(
hidden_states=hidden_states,
expert_routing_table=expert_routing_table,
router_weights=router_weights,
w1=w1_list,
w2=w2_list,
w3=w3_list,
permuted_weights=permuted_weights,
activation=activation,
experts_min=experts_min,
experts_max=experts_max,
)

class _DynamicFusedMOE(torch.nn.Module):
def __init__(self, num_total_experts):
@@ -165,14 +189,33 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
if not LOW_CPU_MEM:
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
else:
w1_weight = torch.nn.Parameter(torch.empty(
num_experts,
intermediate_size_per_partition,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w1_weight", w1_weight)
set_weight_attrs(w1_weight, extra_weight_attrs)

w3_weight = torch.nn.Parameter(torch.empty(
num_experts,
intermediate_size_per_partition,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w3_weight", w3_weight)
set_weight_attrs(w3_weight, extra_weight_attrs)

# down_proj (row parallel)
w2_weight = torch.nn.Parameter(torch.empty(
@@ -327,8 +370,8 @@ def forward_hpu_original(
max_expert = (i + 1) * n_expert_slice
# w13_list_slice = [w13_list[i].weight.squeeze() for i in range(min_expert, max_expert)]
# w2_list_slice = [w2_list[i].weight.squeeze() for i in range(min_expert, max_expert)]
w13_list_slice = [layer.w13_weight[j].squeeze().clone() for j in range(min_expert, max_expert)]
w2_list_slice = [layer.w2_weight[j].squeeze().clone() for j in range(min_expert, max_expert)]
w13_list_slice = [layer.w13_weight[j].squeeze() for j in range(min_expert, max_expert)]
w2_list_slice = [layer.w2_weight[j].squeeze() for j in range(min_expert, max_expert)]
final_hidden_states += torch.ops.hpu.mixture_of_experts(hidden_states=x,
expert_routing_table=topk_ids.to(torch.int64),
router_weights=topk_weights.to(x.dtype),
@@ -612,21 +655,54 @@ def __init__(
max_expert = (i + 1) * n_expert_slice
# Note: cloning the weights will cause OOM.
# rank_debug(f"i:{i}, num_experts:{num_experts} loading experts from {min_expert} to {max_expert}, layer.w13_weight.shape : {layer.w13_weight.shape}")
w13_list_slice = [
layer.w13_weight[j]
for j in range(min_expert, max_expert)
]
if not LOW_CPU_MEM:
w13_list_slice = [
layer.w13_weight[j]
for j in range(min_expert, max_expert)
]
else:
w1_list_slice = [
layer.w1_weight[j]
for j in range(min_expert, max_expert)
]
w3_list_slice = [
layer.w3_weight[j]
for j in range(min_expert, max_expert)
]
w2_list_slice = [
layer.w2_weight[j]
for j in range(min_expert, max_expert)
]
for index in range(len(w13_list_slice)):
_temp_expert_group.MoeOp.w13_list[index].set_weight(
w13_list_slice[index]
)
_temp_expert_group.MoeOp.w2_list[index].set_weight(
w2_list_slice[index]
)
for index in range(len(w2_list_slice)):
if not LOW_CPU_MEM:
_temp_expert_group.MoeOp.w13_list[index].set_weight(
w13_list_slice[index]
)
_temp_expert_group.MoeOp.w2_list[index].set_weight(
w2_list_slice[index]
)
else:
_temp_expert_group.MoeOp.w1_list[index].set_weight(
torch.nn.Parameter(torch.empty(
w1_list_slice[index].shape,
dtype=w1_list_slice[index].dtype,
),
requires_grad=False)
)
_temp_expert_group.MoeOp.w3_list[index].set_weight(
torch.nn.Parameter(torch.empty(
w3_list_slice[index].shape,
dtype=w3_list_slice[index].dtype,
),
requires_grad=False)
)
_temp_expert_group.MoeOp.w2_list[index].set_weight(
torch.nn.Parameter(torch.empty(
w2_list_slice[index].shape,
dtype=w2_list_slice[index].dtype,
),
requires_grad=False)
)
# FIXME: (Yi) pass `experts_min` and `experts_max` to MoeOp.
setattr(_temp_expert_group.MoeOp, "experts_min", min_expert + ep_shift)
setattr(_temp_expert_group.MoeOp, "experts_max", max_expert - 1 + ep_shift)
@@ -891,13 +967,31 @@ def weight_loader(self, param: torch.nn.Parameter,

# Case model weights
if "weight" in weight_name:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
expert_id=expert_id)
if not LOW_CPU_MEM:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
expert_id=expert_id)
else:
expert_group = expert_id // (self.num_experts//self.num_expert_group)
lst_id = expert_id % (self.num_experts//self.num_expert_group)
expert = getattr(self, f"_temp_expert_group_{expert_group}")
if shard_id in ("w1", "w3"):
if shard_id == "w1":
expert_data = expert.MoeOp.w1_list[lst_id].weight
expert_data.data = loaded_weight
else:
assert shard_id == "w3"
expert_data = expert.MoeOp.w3_list[lst_id].weight
expert_data.data = loaded_weight

elif shard_id == "w2":
# load w2
expert_data = expert.MoeOp.w2_list[lst_id].weight
expert_data.data = loaded_weight
return

@staticmethod
@@ -973,14 +1067,19 @@ def make_expert_params_mapping(
ckpt_up_proj_name: str,
num_experts: int) -> List[Tuple[str, str, int, str]]:

mapping = {
ckpt_gate_proj_name: "experts.w1_" if LOW_CPU_MEM else "experts.w13_",
ckpt_down_proj_name: "experts.w2_",
ckpt_up_proj_name: "experts.w3_" if LOW_CPU_MEM else "experts.w13_",
}
return [
# (param_name, weight_name, expert_id, shard_id)
("experts.w13_" if weight_name
in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
(mapping[weight_name],
f"experts.{expert_id}.{weight_name}.", expert_id, shard_id)
for expert_id in range(num_experts) for shard_id, weight_name in [
("w1", ckpt_gate_proj_name),
("w2", ckpt_down_proj_name),
("w3", ckpt_up_proj_name),
]
]
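
To summarize the weight-layout change in `layer.py` above, here is a self-contained toy sketch (expert count and tensor sizes are illustrative; only the fused-vs-split idea and the checkpoint-name mapping come from the diff). With `LOW_CPU_MEM=1`, the fused per-expert `w13` (gate+up) tensor is replaced by separate `w1` and `w3` tensors, so each projection can be loaded and converted on its own, which is presumably what keeps peak host-DRAM usage under the fused path's 2T requirement:

```python
import os
import torch

LOW_CPU_MEM = os.environ.get("LOW_CPU_MEM", "0") == "1"

num_experts, hidden_size, intermediate_size = 4, 16, 32  # toy sizes

if not LOW_CPU_MEM:
    # Fused gate_up_proj: one (2 * intermediate, hidden) tensor per expert.
    w13_list = [torch.empty(2 * intermediate_size, hidden_size, dtype=torch.bfloat16)
                for _ in range(num_experts)]
else:
    # Split projections: separate gate (w1) and up (w3) tensors per expert.
    w1_list = [torch.empty(intermediate_size, hidden_size, dtype=torch.bfloat16)
               for _ in range(num_experts)]
    w3_list = [torch.empty(intermediate_size, hidden_size, dtype=torch.bfloat16)
               for _ in range(num_experts)]

# down_proj (w2) is unchanged in both modes.
w2_list = [torch.empty(hidden_size, intermediate_size, dtype=torch.bfloat16)
           for _ in range(num_experts)]

# Checkpoint-name mapping, mirroring make_expert_params_mapping above:
# gate_proj -> "experts.w1_" when LOW_CPU_MEM else "experts.w13_"
# up_proj   -> "experts.w3_" when LOW_CPU_MEM else "experts.w13_"
# down_proj -> "experts.w2_" in both modes
```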

4 changes: 3 additions & 1 deletion vllm/worker/hpu_model_runner.py
@@ -2361,7 +2361,8 @@ def try_revert_dummy_output_tokens():
LoraMask.setLoraMask(
lora_logits_mask.index_select(
0, sampling_metadata.selected_token_indices))

# NOTE: Synchronize here in case the driver rank's compute_logits hangs on two nodes.
torch.hpu.synchronize()
# Compute the logits.
with self.profiler.record_event(
'internal',
@@ -2373,6 +2374,7 @@ def try_revert_dummy_output_tokens():
sampling_metadata.selected_token_indices = None
logits = self.model.compute_logits(hidden_states,
sampling_metadata)

htorch.core.mark_step()
# Only perform sampling in the driver worker.
if not self.is_driver_worker:
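
Finally, a minimal sketch of the ordering the `hpu_model_runner.py` change enforces (the wrapper function is hypothetical; only `torch.hpu.synchronize()`, `compute_logits`, and the two-node hang note come from the diff):

```python
import torch
import habana_frameworks.torch as htorch  # noqa: F401  (registers the HPU backend)

def compute_logits_with_sync(model, hidden_states, sampling_metadata):
    # Drain pending HPU work before computing logits; per the PR's note this is
    # done in case the driver rank's compute_logits hangs on two nodes.
    torch.hpu.synchronize()
    return model.compute_logits(hidden_states, sampling_metadata)
```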