Enable layer-by-layer convert for vllm Deepseek model #2137

Merged 3 commits on Mar 5, 2025
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -1,5 +1,5 @@
 ci:
-  autofix_prs: true
+  autofix_prs: false
   autoupdate_schedule: quarterly

 repos:
6 changes: 5 additions & 1 deletion neural_compressor/torch/algorithms/fp8_quant/_core/common.py
@@ -225,7 +225,11 @@ def convert_scales_to_tensors_dict(scales_obj, scales_file_format, hp_dtype, dev
     # FIXME (Yi) revert change
     "FusedMoE": ModuleInfo("linear", PatchedMixtralMoE, False),
     "GaudiMixtralSparseMoeBlock": ModuleInfo("dynamic_moe", PatchedGaudiMixtralSparseMoeBlock),
-    "VllmMixtureOfExpertsOp": ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOp),
+    "VllmMixtureOfExpertsOp": (
+        ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpV2)
+        if os.getenv("LOW_CPU_MEM", "0") == "1"
+        else ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpV1)
+    ),
 }


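The table entry above is evaluated when common.py is imported, so LOW_CPU_MEM has to be set in the environment before neural_compressor is imported for the layer-by-layer V2 path to be selected. A minimal standalone sketch of the dispatch pattern, with stub classes standing in for the real patched modules defined later in this diff:

    import os

    # Stubs for illustration only; the real classes live in the patched-module file below.
    class PatchedVllmMixtureOfExpertsOpV1: ...
    class PatchedVllmMixtureOfExpertsOpV2: ...

    # Export LOW_CPU_MEM=1 before import time to pick the low-CPU-memory V2 path.
    MOE_PATCH_CLS = (
        PatchedVllmMixtureOfExpertsOpV2
        if os.getenv("LOW_CPU_MEM", "0") == "1"
        else PatchedVllmMixtureOfExpertsOpV1
    )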
11 changes: 9 additions & 2 deletions neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py
@@ -133,6 +133,10 @@ def prepare_model(model, mod_list, measurement, scale_file, scaling_method_name,
     for name, mod in model.named_modules():
         mod_type_str = mod.__class__.__name__
         logger.debug(f"start to handle module {name}, type: {mod_type_str}")
+        origin_name = name
+        # TODO: (Mengni) optimize the name conversion method between MoEV1 and MoEV2
+        if "w1_list" in name or "w3_list" in name:
+            name = name.replace("w1_list", "w13_list") if "w1_list" in name else name.replace("w3_list", "w13_list")
         if name in mod_list and name not in scales and config.cfg["use_stats_files"] and name not in measurement:
             if mod_default_dict[mod_type_str].should_measure_and_quant:
                 if not config.cfg["ignore_modules_wo_measures"]:
@@ -143,7 +147,7 @@ def prepare_model(model, mod_list, measurement, scale_file, scaling_method_name,
                 continue
         # When offloading weight to disk, need to transfer the weight from disk to cpu using hf_hook
         apply_hf_hook(mod)
-        if name in mod_list:
+        if origin_name in mod_list:
             set_hqt_config(mod, config) # set config in the module, as it consumed by the patched module
             mod_extra_config, save_file = load_layer_scales(mod, name, config,
                                                             mod_type_str, measurement,
@@ -155,7 +159,7 @@ def prepare_model(model, mod_list, measurement, scale_file, scaling_method_name,
                 quantize_params(mod, mod_extra_config)
             logger.debug(f"patching module {name}")
             patch_module(mod, mod_extra_config, mod_default_dict)
-            # show_mem_info()
+            name = origin_name
             patched_modules.append(name)
             patched_module_types.add(type(mod))
             htcore.mark_step()
@@ -167,6 +171,9 @@ def prepare_model(model, mod_list, measurement, scale_file, scaling_method_name,
     logger.debug("Patched modules: %s", patched_modules)
     logger.debug("Total patched modules: %d", len(patched_modules))
     model = model.to(cur_accelerator.name())
+    for _, mod in model.named_modules():
+        if hasattr(mod, "post_process"):
+            mod.post_process()
     torch.distributed.barrier()
     convert_fp16_to_bf16(model)
     cur_accelerator.synchronize()
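The w1_list/w3_list to w13_list renaming above reappears in scale.py below: the live module tree exposes separate per-expert w1/w3 projections, while the measurement and scale files are keyed by the fused w13 name. A hypothetical helper (remap_moe_name and the example module path are illustrations, not part of the patch) that captures the mapping:

    def remap_moe_name(name: str) -> str:
        # Look up separate "w1_list"/"w3_list" submodule names under the fused
        # "w13_list" key used by the measurement/scale files.
        if "w1_list" in name:
            return name.replace("w1_list", "w13_list")
        if "w3_list" in name:
            return name.replace("w3_list", "w13_list")
        return name

    assert remap_moe_name("model.layers.3.mlp.experts.w1_list.7") == "model.layers.3.mlp.experts.w13_list.7"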
2 changes: 2 additions & 0 deletions neural_compressor/torch/algorithms/fp8_quant/_core/scale.py
@@ -29,6 +29,8 @@ def load_layer_scales(mod, mod_name, config, mod_type_str, measurement, scales,
     )
     mod_extra_config = None
     if mod_name in scales or not config.cfg["use_stats_files"] or mod_name in measurement:
+        if "w1_list" in mod_name or "w3_list" in mod_name:
+            mod_name = mod_name.replace("w1_list", "w13_list") if "w1_list" in mod_name else mod_name.replace("w3_list", "w13_list")
         op_for_scale_obj = ops_quantizer.get_op_quantizer(module_type, scaling_method_name, mod,
                                                           measurement.get(mod_name, None), scale_config)
         if mod_name not in scales:
@@ -14,6 +14,7 @@

 import torch
 import torch.nn as nn
+import copy
 import types
 import habana_frameworks.torch.core as htcore
 from .quant_config import QuantMode
@@ -575,12 +576,18 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
         # remove the MoE weights that are quanted by PatchedMoeMatmul
         if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
-            delattr(mod, "w13_weight")
-            delattr(mod, "w2_weight")
-            setattr(mod, "w13_weight", None)
-            setattr(mod, "w2_weight", None)
-            setattr(self, "w13_weight", None)
-            setattr(self, "w2_weight", None)
+            if hasattr(mod, "w13_weight"):
+                delattr(mod, "w13_weight")
+                setattr(mod, "w13_weight", None)
+            if hasattr(mod, "w2_weight"):
+                delattr(mod, "w2_weight")
+                setattr(self, "w2_weight", None)
+            if hasattr(mod, "w1_weight"):
+                delattr(mod, "w1_weight")
+                setattr(self, "w1_weight", None)
+            if hasattr(mod, "w3_weight"):
+                delattr(mod, "w3_weight")
+                setattr(self, "w3_weight", None)
         self.forward = self.forward_orig


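The hasattr-guarded delattr/None sequence above drops weight copies that PatchedMoeMatmul has already taken over (per the comment in the hunk). A toy, self-contained illustration of the idiom; the module and tensor here are placeholders, not the library's objects:

    import torch
    import torch.nn as nn

    mod = nn.Module()
    mod.w13_weight = nn.Parameter(torch.zeros(4, 4))
    # delattr() removes the registered parameter so its storage can be freed once
    # nothing else references it; re-setting the attribute to None keeps later
    # `mod.w13_weight is None` style checks from raising AttributeError.
    delattr(mod, "w13_weight")
    setattr(mod, "w13_weight", None)
    assert mod.w13_weight is None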
@@ -714,8 +721,7 @@ def extra_repr(self) -> str:
             get_current_repr(self, *member_names),
         )

-
-class PatchedVllmMixtureOfExpertsOp(PatchedModuleBase):
+class PatchedVllmMixtureOfExpertsOpV1(PatchedModuleBase):
     def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
         self.experts_min = self.orig_mod.experts_min
@@ -808,6 +814,76 @@ def extra_repr(self) -> str:
         )


+class PatchedVllmMixtureOfExpertsOpV2(PatchedVllmMixtureOfExpertsOpV1):
+    def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
+        PatchedModuleBase.__init__(self, mod, parent, mod_extra_config, *args, **kwargs)
+        self.experts_min = self.orig_mod.experts_min
+        self.experts_max = self.orig_mod.experts_max
+        if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
+            self.forward = self.forward_quant
+            self.dynamic_moe_op = get_quantized_func_wrapper(OP_TYPE.DYNAMIC_MOE, self.scale_format)
+            self.quant_input = self._mod_extra_config.inputs[0]
+            self.register_scale("scale_input", mod_extra_config.scale.inputs[0], self.scale_format)
+            # FIXME: (Yi) should we take the scale_intermediate from the mod_extra_config.scale.outputs?
+            self.register_scale(
+                "scale_intermediate",
+                [mod_extra_config.scale.inputs[x] for x in range(1, self.num_experts+1)],
+                self.scale_format,
+            )
+            for i in range(self.num_experts):
+                self.w1_list[i].weight.data = self.w1_list[i].weight.squeeze().t().contiguous()
+                self.w3_list[i].weight.data = self.w3_list[i].weight.squeeze().t().contiguous()
+                self.w2_list[i].weight.data = self.w2_list[i].weight.squeeze().t().contiguous()
+        elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE):
+            self.forward = self.forward_measure
+
+    def post_process(self):
+        self.dynamic_moe_op = get_quantized_func_wrapper(OP_TYPE.DYNAMIC_MOE_FUSED_WEIGHTS, self.scale_format)
+        self.w13_weight = []
+        self.w2_weight = []
+        self.scale_w1 = []
+        self.scale_w2 = []
+        for i in range(self.num_experts):
+            self.w2_weight.append(self.w2_list[i].weight.contiguous())
+            w1_weight = self.w1_list[i].weight
+            w3_weight = self.w3_list[i].weight
+            self.w13_weight.append(torch.cat((w1_weight, w3_weight), dim=0).contiguous())
+            # TODO: (Mengni) enhance the scale process for different scale formats
+            self.scale_w1.append(max(self.w1_list[i].scale_weight, self.w3_list[i].scale_weight))
+            self.scale_w2.append(self.w2_list[i].scale_weight)
+            delattr(self.w1_list[i], "weight")
+            delattr(self.w3_list[i], "weight")
+            delattr(self.w2_list[i], "weight")
+            htcore.mark_step()
+        delattr(self, "w1_list")
+        delattr(self, "w3_list")
+        delattr(self, "w2_list")
+
+    def forward_quant(self,
+                      hidden_states,
+                      expert_routing_table,
+                      router_weights,
+                      permuted_weights=True,
+                      activation="silu"):
+        qinput = self.quant_input(hidden_states)
+        output = self.dynamic_moe_op(
+            hidden_states=qinput,
+            expert_routing_table=expert_routing_table,
+            router_weights=router_weights,
+            w12=self.w13_weight,
+            w3=self.w2_weight,
+            d_scale_w12=self.scale_w1,
+            d_scale_w3=self.scale_w2,
+            d_scale_hidden_states=self.scale_input,
+            d_scale_intermediate_hidden_states=self.scale_intermediate,
+            permuted_weights=permuted_weights,
+            activation=activation,
+            experts_min=self.experts_min,
+            experts_max=self.experts_max,
+        )
+        return output
+
+
 class PatchedKVCache(PatchedModuleBase):
     # Module to patch KVCache module from llama model
     def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
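In PatchedVllmMixtureOfExpertsOpV2 above, register_scale pulls mod_extra_config.scale.inputs[0] for the hidden states and inputs[1..num_experts] for the per-expert intermediate activations. A toy sketch of that indexing; the 4-expert count and the scale values are invented:

    num_experts = 4
    # inputs[0] scales the hidden_states entering the MoE op; the remaining
    # entries scale each expert's intermediate activations, mirroring the
    # register_scale("scale_intermediate", ...) call above.
    inputs = [0.5, 0.11, 0.12, 0.13, 0.14]
    scale_input = inputs[0]
    scale_intermediate = [inputs[x] for x in range(1, num_experts + 1)]
    assert len(scale_intermediate) == num_experts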