
Commit 7b2108b

mengniwang95 and Yi4Liu authored
Enable layer-by-layer convert for vllm Deepseek model (#2137)
* enable layer-by-layer
  Signed-off-by: Mengni Wang <mengni.wang@intel.com>

* disable pre-commit
  Signed-off-by: Mengni Wang <mengni.wang@intel.com>

* align low_cpu_mem check
  Change-Id: I0f2d40d7b1bfa9e1c07ccb35dd6310f92bb793ae
  Signed-off-by: Yi Liu <yiliu4@habana.ai>

---------

Signed-off-by: Mengni Wang <mengni.wang@intel.com>
Signed-off-by: Yi Liu <yiliu4@habana.ai>
Co-authored-by: Yi Liu <yiliu4@habana.ai>
1 parent: d0e6c2e

5 files changed (+101, -12 lines)


.pre-commit-config.yaml

+1, -1

@@ -1,5 +1,5 @@
 ci:
-  autofix_prs: true
+  autofix_prs: false
   autoupdate_schedule: quarterly

 repos:

neural_compressor/torch/algorithms/fp8_quant/_core/common.py

+5, -1

@@ -225,7 +225,11 @@ def convert_scales_to_tensors_dict(scales_obj, scales_file_format, hp_dtype, dev
     # FIXME (Yi) revert change
     "FusedMoE": ModuleInfo("linear", PatchedMixtralMoE, False),
     "GaudiMixtralSparseMoeBlock": ModuleInfo("dynamic_moe", PatchedGaudiMixtralSparseMoeBlock),
-    "VllmMixtureOfExpertsOp": ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOp),
+    "VllmMixtureOfExpertsOp": (
+        ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpV2)
+        if os.getenv("LOW_CPU_MEM", "0") == "1"
+        else ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpV1)
+    ),
 }
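
For context, here is a minimal, self-contained sketch of the selection rule this hunk adds. The helper names below are illustrative only; in the actual change the conditional expression sits directly inside the module-info table in common.py, and the two classes it chooses between are PatchedVllmMixtureOfExpertsOpV1/V2 from helper_modules.py.

import os


def _use_low_cpu_mem() -> bool:
    # Mirrors the check added above: the layer-by-layer path is opt-in via LOW_CPU_MEM=1.
    return os.getenv("LOW_CPU_MEM", "0") == "1"


def pick_patched_moe_class(v1_cls, v2_cls):
    # v1_cls / v2_cls stand in for the real patched MoE classes.
    return v2_cls if _use_low_cpu_mem() else v1_cls


# Usage (strings stand in for the real classes): prints "V2" only when LOW_CPU_MEM=1 is exported.
print(pick_patched_moe_class("V1", "V2"))

In other words, exporting LOW_CPU_MEM=1 before running quantization routes VllmMixtureOfExpertsOp modules through the new layer-by-layer V2 op; any other value keeps the original V1 behavior.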

neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py

+9, -2

@@ -133,6 +133,10 @@ def prepare_model(model, mod_list, measurement, scale_file, scaling_method_name,
     for name, mod in model.named_modules():
         mod_type_str = mod.__class__.__name__
         logger.debug(f"start to handle module {name}, type: {mod_type_str}")
+        origin_name = name
+        # TODO: (Mengni) optimize the name conversion method between MoEV1 and MoEV2
+        if "w1_list" in name or "w3_list" in name:
+            name = name.replace("w1_list", "w13_list") if "w1_list" in name else name.replace("w3_list", "w13_list")
         if name in mod_list and name not in scales and config.cfg["use_stats_files"] and name not in measurement:
             if mod_default_dict[mod_type_str].should_measure_and_quant:
                 if not config.cfg["ignore_modules_wo_measures"]:

@@ -143,7 +147,7 @@ def prepare_model(model, mod_list, measurement, scale_file, scaling_method_name,
                     continue
         # When offloading weight to disk, need to transfer the weight from disk to cpu using hf_hook
         apply_hf_hook(mod)
-        if name in mod_list:
+        if origin_name in mod_list:
             set_hqt_config(mod, config)  # set config in the module, as it consumed by the patched module
             mod_extra_config, save_file = load_layer_scales(mod, name, config,
                                                             mod_type_str, measurement,

@@ -155,7 +159,7 @@ def prepare_model(model, mod_list, measurement, scale_file, scaling_method_name,
                 quantize_params(mod, mod_extra_config)
             logger.debug(f"patching module {name}")
             patch_module(mod, mod_extra_config, mod_default_dict)
-            # show_mem_info()
+            name = origin_name
             patched_modules.append(name)
         patched_module_types.add(type(mod))
         htcore.mark_step()

@@ -167,6 +171,9 @@ def prepare_model(model, mod_list, measurement, scale_file, scaling_method_name,
     logger.debug("Patched modules: %s", patched_modules)
     logger.debug("Total patched modules: %d", len(patched_modules))
     model = model.to(cur_accelerator.name())
+    for _, mod in model.named_modules():
+        if hasattr(mod, "post_process"):
+            mod.post_process()
     torch.distributed.barrier()
     convert_fp16_to_bf16(model)
     cur_accelerator.synchronize()
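
The two additions above work together: the measurement and scale files store fused "w13_list" entries, while the vLLM DeepSeek module tree exposes separate "w1_list" and "w3_list" submodules. prepare_model therefore looks scales up under the fused name and restores the original name before recording the patched module, and after patching it calls post_process() on any module that defines it (the new V2 MoE op). A small standalone sketch of the name mapping, with made-up module paths for illustration:

def to_measurement_name(name: str) -> str:
    # Map a per-expert submodule name to the fused name used in the stats files,
    # as done at the top of the prepare_model loop above.
    if "w1_list" in name:
        return name.replace("w1_list", "w13_list")
    if "w3_list" in name:
        return name.replace("w3_list", "w13_list")
    return name


# Illustrative module paths, not taken from a real model dump:
assert to_measurement_name("model.layers.3.mlp.experts.w1_list.7") == "model.layers.3.mlp.experts.w13_list.7"
assert to_measurement_name("model.layers.3.mlp.experts.w3_list.7") == "model.layers.3.mlp.experts.w13_list.7"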

neural_compressor/torch/algorithms/fp8_quant/_core/scale.py

+2, -0

@@ -29,6 +29,8 @@ def load_layer_scales(mod, mod_name, config, mod_type_str, measurement, scales,
     )
     mod_extra_config = None
     if mod_name in scales or not config.cfg["use_stats_files"] or mod_name in measurement:
+        if "w1_list" in mod_name or "w3_list" in mod_name:
+            mod_name = mod_name.replace("w1_list", "w13_list") if "w1_list" in mod_name else mod_name.replace("w3_list", "w13_list")
         op_for_scale_obj = ops_quantizer.get_op_quantizer(module_type, scaling_method_name, mod,
                                                           measurement.get(mod_name, None), scale_config)
         if mod_name not in scales:

neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py

+84, -8

@@ -14,6 +14,7 @@

 import torch
 import torch.nn as nn
+import copy
 import types
 import habana_frameworks.torch.core as htcore
 from .quant_config import QuantMode

@@ -575,12 +576,18 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
         # remove the MoE weights that are quanted by PatchedMoeMatmul
         if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
-            delattr(mod, "w13_weight")
-            delattr(mod, "w2_weight")
-            setattr(mod, "w13_weight", None)
-            setattr(mod, "w2_weight", None)
-            setattr(self, "w13_weight", None)
-            setattr(self, "w2_weight", None)
+            if hasattr(mod, "w13_weight"):
+                delattr(mod, "w13_weight")
+                setattr(mod, "w13_weight", None)
+            if hasattr(mod, "w2_weight"):
+                delattr(mod, "w2_weight")
+                setattr(self, "w2_weight", None)
+            if hasattr(mod, "w1_weight"):
+                delattr(mod, "w1_weight")
+                setattr(self, "w1_weight", None)
+            if hasattr(mod, "w3_weight"):
+                delattr(mod, "w3_weight")
+                setattr(self, "w3_weight", None)
         self.forward = self.forward_orig

@@ -714,8 +721,7 @@ def extra_repr(self) -> str:
             get_current_repr(self, *member_names),
         )

-
-class PatchedVllmMixtureOfExpertsOp(PatchedModuleBase):
+class PatchedVllmMixtureOfExpertsOpV1(PatchedModuleBase):
     def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
         self.experts_min = self.orig_mod.experts_min

@@ -808,6 +814,76 @@ def extra_repr(self) -> str:
         )


+class PatchedVllmMixtureOfExpertsOpV2(PatchedVllmMixtureOfExpertsOpV1):
+    def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
+        PatchedModuleBase.__init__(self, mod, parent, mod_extra_config, *args, **kwargs)
+        self.experts_min = self.orig_mod.experts_min
+        self.experts_max = self.orig_mod.experts_max
+        if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
+            self.forward = self.forward_quant
+            self.dynamic_moe_op = get_quantized_func_wrapper(OP_TYPE.DYNAMIC_MOE, self.scale_format)
+            self.quant_input = self._mod_extra_config.inputs[0]
+            self.register_scale("scale_input", mod_extra_config.scale.inputs[0], self.scale_format)
+            # FIXME: (Yi) should we take the scale_intermediate from the mod_extra_config.scale.outputs?
+            self.register_scale(
+                "scale_intermediate",
+                [mod_extra_config.scale.inputs[x] for x in range(1, self.num_experts+1)],
+                self.scale_format,
+            )
+            for i in range(self.num_experts):
+                self.w1_list[i].weight.data = self.w1_list[i].weight.squeeze().t().contiguous()
+                self.w3_list[i].weight.data = self.w3_list[i].weight.squeeze().t().contiguous()
+                self.w2_list[i].weight.data = self.w2_list[i].weight.squeeze().t().contiguous()
+        elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE):
+            self.forward = self.forward_measure
+
+    def post_process(self):
+        self.dynamic_moe_op = get_quantized_func_wrapper(OP_TYPE.DYNAMIC_MOE_FUSED_WEIGHTS, self.scale_format)
+        self.w13_weight = []
+        self.w2_weight = []
+        self.scale_w1 = []
+        self.scale_w2 = []
+        for i in range(self.num_experts):
+            self.w2_weight.append(self.w2_list[i].weight.contiguous())
+            w1_weight = self.w1_list[i].weight
+            w3_weight = self.w3_list[i].weight
+            self.w13_weight.append(torch.cat((w1_weight, w3_weight), dim=0).contiguous())
+            # TODO: (Mengni) enhance the scale process for different scale formats
+            self.scale_w1.append(max(self.w1_list[i].scale_weight, self.w3_list[i].scale_weight))
+            self.scale_w2.append(self.w2_list[i].scale_weight)
+            delattr(self.w1_list[i], "weight")
+            delattr(self.w3_list[i], "weight")
+            delattr(self.w2_list[i], "weight")
+            htcore.mark_step()
+        delattr(self, "w1_list")
+        delattr(self, "w3_list")
+        delattr(self, "w2_list")
+
+    def forward_quant(self,
+                      hidden_states,
+                      expert_routing_table,
+                      router_weights,
+                      permuted_weights=True,
+                      activation="silu"):
+        qinput = self.quant_input(hidden_states)
+        output = self.dynamic_moe_op(
+            hidden_states=qinput,
+            expert_routing_table=expert_routing_table,
+            router_weights=router_weights,
+            w12=self.w13_weight,
+            w3=self.w2_weight,
+            d_scale_w12=self.scale_w1,
+            d_scale_w3=self.scale_w2,
+            d_scale_hidden_states=self.scale_input,
+            d_scale_intermediate_hidden_states=self.scale_intermediate,
+            permuted_weights=permuted_weights,
+            activation=activation,
+            experts_min=self.experts_min,
+            experts_max=self.experts_max,
+        )
+        return output
+
+
 class PatchedKVCache(PatchedModuleBase):
     # Module to patch KVCache module from llama model
     def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
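
To make the fusion performed in PatchedVllmMixtureOfExpertsOpV2.post_process concrete, here is a small, self-contained illustration of the per-expert step: the w1 and w3 projection weights (gate and up projections in the usual MoE layout) are concatenated along dim 0 into a single w13 tensor, and one weight scale is kept as the max of the two per-tensor scales. The shapes and scale values below are made up for the example.

import torch

hidden, intermediate = 16, 32
w1 = torch.randn(intermediate, hidden)        # gate projection weight of one expert
w3 = torch.randn(intermediate, hidden)        # up projection weight of the same expert
scale_w1, scale_w3 = 0.02, 0.03               # per-tensor weight scales (illustrative values)

# Fused weight and conservative shared scale, as in post_process() above.
w13 = torch.cat((w1, w3), dim=0).contiguous()
scale_w13 = max(scale_w1, scale_w3)

assert w13.shape == (2 * intermediate, hidden)

Deleting each expert's original weight attributes and calling htcore.mark_step() as the loop proceeds is what keeps peak memory low during the layer-by-layer conversion.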
