diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 5a467a878b6..b3897b04ebe 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,5 +1,5 @@
 ci:
-  autofix_prs: true
+  autofix_prs: false
   autoupdate_schedule: quarterly
 
 repos:
diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/common.py b/neural_compressor/torch/algorithms/fp8_quant/_core/common.py
index ca41443377a..b75511344e9 100644
--- a/neural_compressor/torch/algorithms/fp8_quant/_core/common.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/_core/common.py
@@ -225,7 +225,11 @@ def convert_scales_to_tensors_dict(scales_obj, scales_file_format, hp_dtype, dev
     # FIXME (Yi) revert change
     "FusedMoE": ModuleInfo("linear", PatchedMixtralMoE, False),
     "GaudiMixtralSparseMoeBlock": ModuleInfo("dynamic_moe", PatchedGaudiMixtralSparseMoeBlock),
-    "VllmMixtureOfExpertsOp": ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOp),
+    "VllmMixtureOfExpertsOp": (
+        ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpV2)
+        if os.getenv("LOW_CPU_MEM", "0") == "1"
+        else ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpV1)
+    ),
 }
diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py b/neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py
index 970be916385..7af50b68ada 100644
--- a/neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py
@@ -133,6 +133,10 @@ def prepare_model(model, mod_list, measurement, scale_file, scaling_method_name,
     for name, mod in model.named_modules():
         mod_type_str = mod.__class__.__name__
         logger.debug(f"start to handle module {name}, type: {mod_type_str}")
+        origin_name = name
+        # TODO: (Mengni) optimize the name conversion method between MoEV1 and MoEV2
+        if "w1_list" in name or "w3_list" in name:
+            name = name.replace("w1_list", "w13_list") if "w1_list" in name else name.replace("w3_list", "w13_list")
         if name in mod_list and name not in scales and config.cfg["use_stats_files"] and name not in measurement:
             if mod_default_dict[mod_type_str].should_measure_and_quant:
                 if not config.cfg["ignore_modules_wo_measures"]:
@@ -143,7 +147,7 @@ def prepare_model(model, mod_list, measurement, scale_file, scaling_method_name,
                     continue
         # When offloading weight to disk, need to transfer the weight from disk to cpu using hf_hook
         apply_hf_hook(mod)
-        if name in mod_list:
+        if origin_name in mod_list:
            set_hqt_config(mod, config)  # set config in the module, as it consumed by the patched module
            mod_extra_config, save_file = load_layer_scales(mod, name, config, mod_type_str, measurement,
@@ -155,7 +159,7 @@ def prepare_model(model, mod_list, measurement, scale_file, scaling_method_name,
                 quantize_params(mod, mod_extra_config)
             logger.debug(f"patching module {name}")
             patch_module(mod, mod_extra_config, mod_default_dict)
-            # show_mem_info()
+            name = origin_name
             patched_modules.append(name)
             patched_module_types.add(type(mod))
             htcore.mark_step()
@@ -167,6 +171,9 @@ def prepare_model(model, mod_list, measurement, scale_file, scaling_method_name,
     logger.debug("Patched modules: %s", patched_modules)
     logger.debug("Total patched modules: %d", len(patched_modules))
     model = model.to(cur_accelerator.name())
+    for _, mod in model.named_modules():
+        if hasattr(mod, "post_process"):
+            mod.post_process()
     torch.distributed.barrier()
     convert_fp16_to_bf16(model)
     cur_accelerator.synchronize()
diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/scale.py b/neural_compressor/torch/algorithms/fp8_quant/_core/scale.py
index 61e8c6c3ea3..110319496e4 100644
--- a/neural_compressor/torch/algorithms/fp8_quant/_core/scale.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/_core/scale.py
@@ -29,6 +29,8 @@ def load_layer_scales(mod, mod_name, config, mod_type_str, measurement, scales,
     )
     mod_extra_config = None
     if mod_name in scales or not config.cfg["use_stats_files"] or mod_name in measurement:
+        if "w1_list" in mod_name or "w3_list" in mod_name:
+            mod_name = mod_name.replace("w1_list", "w13_list") if "w1_list" in mod_name else mod_name.replace("w3_list", "w13_list")
         op_for_scale_obj = ops_quantizer.get_op_quantizer(module_type, scaling_method_name, mod,
                                                           measurement.get(mod_name, None), scale_config)
         if mod_name not in scales:
diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
index ba66f987588..e9c040a0dea 100644
--- a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
@@ -14,6 +14,7 @@
 import torch
 import torch.nn as nn
+import copy
 import types
 import habana_frameworks.torch.core as htcore
 from .quant_config import QuantMode
@@ -575,12 +576,18 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
         # remove the MoE weights that are quanted by PatchedMoeMatmul
         if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
-            delattr(mod, "w13_weight")
-            delattr(mod, "w2_weight")
-            setattr(mod, "w13_weight", None)
-            setattr(mod, "w2_weight", None)
-            setattr(self, "w13_weight", None)
-            setattr(self, "w2_weight", None)
+            if hasattr(mod, "w13_weight"):
+                delattr(mod, "w13_weight")
+                setattr(mod, "w13_weight", None)
+            if hasattr(mod, "w2_weight"):
+                delattr(mod, "w2_weight")
+                setattr(self, "w2_weight", None)
+            if hasattr(mod, "w1_weight"):
+                delattr(mod, "w1_weight")
+                setattr(self, "w1_weight", None)
+            if hasattr(mod, "w3_weight"):
+                delattr(mod, "w3_weight")
+                setattr(self, "w3_weight", None)
         self.forward = self.forward_orig
 
@@ -714,8 +721,7 @@ def extra_repr(self) -> str:
             get_current_repr(self, *member_names),
         )
 
-
-class PatchedVllmMixtureOfExpertsOp(PatchedModuleBase):
+class PatchedVllmMixtureOfExpertsOpV1(PatchedModuleBase):
     def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
         self.experts_min = self.orig_mod.experts_min
@@ -808,6 +814,76 @@ def extra_repr(self) -> str:
         )
 
 
+class PatchedVllmMixtureOfExpertsOpV2(PatchedVllmMixtureOfExpertsOpV1):
+    def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
+        PatchedModuleBase.__init__(self, mod, parent, mod_extra_config, *args, **kwargs)
+        self.experts_min = self.orig_mod.experts_min
+        self.experts_max = self.orig_mod.experts_max
+        if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
+            self.forward = self.forward_quant
+            self.dynamic_moe_op = get_quantized_func_wrapper(OP_TYPE.DYNAMIC_MOE, self.scale_format)
+            self.quant_input = self._mod_extra_config.inputs[0]
+            self.register_scale("scale_input", mod_extra_config.scale.inputs[0], self.scale_format)
+            # FIXME: (Yi) should we take the scale_intermediate from the mod_extra_config.scale.outputs?
+            self.register_scale(
+                "scale_intermediate",
+                [mod_extra_config.scale.inputs[x] for x in range(1, self.num_experts+1)],
+                self.scale_format,
+            )
+            for i in range(self.num_experts):
+                self.w1_list[i].weight.data = self.w1_list[i].weight.squeeze().t().contiguous()
+                self.w3_list[i].weight.data = self.w3_list[i].weight.squeeze().t().contiguous()
+                self.w2_list[i].weight.data = self.w2_list[i].weight.squeeze().t().contiguous()
+        elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE):
+            self.forward = self.forward_measure
+
+    def post_process(self):
+        self.dynamic_moe_op = get_quantized_func_wrapper(OP_TYPE.DYNAMIC_MOE_FUSED_WEIGHTS, self.scale_format)
+        self.w13_weight = []
+        self.w2_weight = []
+        self.scale_w1 = []
+        self.scale_w2 = []
+        for i in range(self.num_experts):
+            self.w2_weight.append(self.w2_list[i].weight.contiguous())
+            w1_weight = self.w1_list[i].weight
+            w3_weight = self.w3_list[i].weight
+            self.w13_weight.append(torch.cat((w1_weight, w3_weight), dim=0).contiguous())
+            # TODO: (Mengni) enhance the scale process for different scale formats
+            self.scale_w1.append(max(self.w1_list[i].scale_weight, self.w3_list[i].scale_weight))
+            self.scale_w2.append(self.w2_list[i].scale_weight)
+            delattr(self.w1_list[i], "weight")
+            delattr(self.w3_list[i], "weight")
+            delattr(self.w2_list[i], "weight")
+            htcore.mark_step()
+        delattr(self, "w1_list")
+        delattr(self, "w3_list")
+        delattr(self, "w2_list")
+
+    def forward_quant(self,
+                      hidden_states,
+                      expert_routing_table,
+                      router_weights,
+                      permuted_weights=True,
+                      activation="silu"):
+        qinput = self.quant_input(hidden_states)
+        output = self.dynamic_moe_op(
+            hidden_states=qinput,
+            expert_routing_table=expert_routing_table,
+            router_weights=router_weights,
+            w12=self.w13_weight,
+            w3=self.w2_weight,
+            d_scale_w12=self.scale_w1,
+            d_scale_w3=self.scale_w2,
+            d_scale_hidden_states=self.scale_input,
+            d_scale_intermediate_hidden_states=self.scale_intermediate,
+            permuted_weights=permuted_weights,
+            activation=activation,
+            experts_min=self.experts_min,
+            experts_max=self.experts_max,
+        )
+        return output
+
+
 class PatchedKVCache(PatchedModuleBase):
     # Module to patch KVCache module from llama model
     def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
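
Reviewer note (not part of the patch): below is a minimal, self-contained sketch of two conventions this diff relies on, namely the LOW_CPU_MEM environment toggle that picks PatchedVllmMixtureOfExpertsOpV2 over V1 in common.py, and the w1_list/w3_list to w13_list name remapping applied in quantize.py and scale.py before measurement and scale lookup. The helpers moe_module_info and to_measurement_name are hypothetical names invented for illustration, and ModuleInfo plus the patched classes are stubbed; this is not the library API.

# Illustrative sketch only; stand-in names, see the note above.
import os
from collections import namedtuple

# Stub stand-ins for neural_compressor's ModuleInfo and the patched MoE ops.
ModuleInfo = namedtuple("ModuleInfo", ["op_type", "patched_cls"])

class PatchedVllmMixtureOfExpertsOpV1: ...
class PatchedVllmMixtureOfExpertsOpV2: ...

def moe_module_info():
    # Hypothetical helper mirroring the common.py change: V2 (which fuses the per-expert
    # w1/w3 weights into w13 during post_process) is selected only when LOW_CPU_MEM=1;
    # otherwise the original V1 behavior is kept.
    if os.getenv("LOW_CPU_MEM", "0") == "1":
        return ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpV2)
    return ModuleInfo("dynamic_moe", PatchedVllmMixtureOfExpertsOpV1)

def to_measurement_name(name):
    # Hypothetical helper mirroring the quantize.py/scale.py change: V2 exposes per-expert
    # w1_list/w3_list submodules, while the stats/scales are keyed by the fused w13_list
    # naming, so the module name is remapped before the lookup.
    if "w1_list" in name:
        return name.replace("w1_list", "w13_list")
    if "w3_list" in name:
        return name.replace("w3_list", "w13_list")
    return name

if __name__ == "__main__":
    os.environ["LOW_CPU_MEM"] = "1"
    assert moe_module_info().patched_cls is PatchedVllmMixtureOfExpertsOpV2
    assert to_measurement_name("model.layers.0.moe.w1_list.3") == "model.layers.0.moe.w13_list.3"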