|
14 | 14 |
|
15 | 15 | import torch
|
16 | 16 | import torch.nn as nn
|
| 17 | +import copy |
17 | 18 | import types
|
18 | 19 | import habana_frameworks.torch.core as htcore
|
19 | 20 | from .quant_config import QuantMode
|
@@ -575,12 +576,18 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
|
575 | 576 | super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
|
576 | 577 | # remove the MoE weights that are quantized by PatchedMoeMatmul
|
577 | 578 | if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
|
578 | | - delattr(mod, "w13_weight") |
579 | | - delattr(mod, "w2_weight") |
580 | | - setattr(mod, "w13_weight", None) |
581 | | - setattr(mod, "w2_weight", None) |
582 | | - setattr(self, "w13_weight", None) |
583 | | - setattr(self, "w2_weight", None) |
| 579 | + if hasattr(mod, "w13_weight"): |
| 580 | + delattr(mod, "w13_weight") |
| 581 | + setattr(mod, "w13_weight", None) |
| 582 | + if hasattr(mod, "w2_weight"): |
| 583 | + delattr(mod, "w2_weight") |
| 584 | + setattr(self, "w2_weight", None) |
| 585 | + if hasattr(mod, "w1_weight"): |
| 586 | + delattr(mod, "w1_weight") |
| 587 | + setattr(self, "w1_weight", None) |
| 588 | + if hasattr(mod, "w3_weight"): |
| 589 | + delattr(mod, "w3_weight") |
| 590 | + setattr(self, "w3_weight", None) |
584 | 591 | self.forward = self.forward_orig
|
585 | 592 |
|
586 | 593 |
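The guarded removal above avoids raising AttributeError when a given expert weight layout (fused w13 or split w1/w3) is not present on the module. A minimal standalone sketch of the same pattern, using an illustrative drop_weight helper and a stand-in `mod` object rather than the actual patched module:

def drop_weight(obj, name):
    # Remove the attribute only if it exists, then leave a None placeholder
    # so later code that still references the name does not fail.
    if hasattr(obj, name):
        delattr(obj, name)
        setattr(obj, name, None)

for attr in ("w13_weight", "w2_weight", "w1_weight", "w3_weight"):
    drop_weight(mod, attr)  # `mod` stands in for the wrapped MoE module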
|
@@ -714,8 +721,7 @@ def extra_repr(self) -> str:
|
714 | 721 | get_current_repr(self, *member_names),
|
715 | 722 | )
|
716 | 723 |
|
717 | | - |
718 | | -class PatchedVllmMixtureOfExpertsOp(PatchedModuleBase): |
| 724 | +class PatchedVllmMixtureOfExpertsOpV1(PatchedModuleBase): |
719 | 725 | def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
|
720 | 726 | super().__init__(mod, parent, mod_extra_config, *args, **kwargs)
|
721 | 727 | self.experts_min = self.orig_mod.experts_min
|
@@ -808,6 +814,76 @@ def extra_repr(self) -> str:
|
808 | 814 | )
|
809 | 815 |
|
810 | 816 |
|
| 817 | +class PatchedVllmMixtureOfExpertsOpV2(PatchedVllmMixtureOfExpertsOpV1): |
| 818 | + def __init__(self, mod, parent, mod_extra_config, *args, **kwargs): |
| 819 | + PatchedModuleBase.__init__(self, mod, parent, mod_extra_config, *args, **kwargs) |
| 820 | + self.experts_min = self.orig_mod.experts_min |
| 821 | + self.experts_max = self.orig_mod.experts_max |
| 822 | + if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]: |
| 823 | + self.forward = self.forward_quant |
| 824 | + self.dynamic_moe_op = get_quantized_func_wrapper(OP_TYPE.DYNAMIC_MOE, self.scale_format) |
| 825 | + self.quant_input = self._mod_extra_config.inputs[0] |
| 826 | + self.register_scale("scale_input", mod_extra_config.scale.inputs[0], self.scale_format) |
| 827 | + # FIXME: (Yi) should we take the scale_intermediate from the mod_extra_config.scale.outputs? |
| 828 | + self.register_scale( |
| 829 | + "scale_intermediate", |
| 830 | + [mod_extra_config.scale.inputs[x] for x in range(1, self.num_experts+1)], |
| 831 | + self.scale_format, |
| 832 | + ) |
| 833 | + for i in range(self.num_experts): |
| 834 | + self.w1_list[i].weight.data = self.w1_list[i].weight.squeeze().t().contiguous() |
| 835 | + self.w3_list[i].weight.data = self.w3_list[i].weight.squeeze().t().contiguous() |
| 836 | + self.w2_list[i].weight.data = self.w2_list[i].weight.squeeze().t().contiguous() |
| 837 | + elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE): |
| 838 | + self.forward = self.forward_measure |
| 839 | + |
| 840 | + def post_process(self): |
| 841 | + self.dynamic_moe_op = get_quantized_func_wrapper(OP_TYPE.DYNAMIC_MOE_FUSED_WEIGHTS, self.scale_format) |
| 842 | + self.w13_weight = [] |
| 843 | + self.w2_weight = [] |
| 844 | + self.scale_w1 = [] |
| 845 | + self.scale_w2 = [] |
| 846 | + for i in range(self.num_experts): |
| 847 | + self.w2_weight.append(self.w2_list[i].weight.contiguous()) |
| 848 | + w1_weight = self.w1_list[i].weight |
| 849 | + w3_weight = self.w3_list[i].weight |
| 850 | + self.w13_weight.append(torch.cat((w1_weight, w3_weight), dim=0).contiguous()) |
| 851 | + # TODO: (Mengni) enhance the scale process for different scale formats |
| 852 | + self.scale_w1.append(max(self.w1_list[i].scale_weight, self.w3_list[i].scale_weight)) |
| 853 | + self.scale_w2.append(self.w2_list[i].scale_weight) |
| 854 | + delattr(self.w1_list[i], "weight") |
| 855 | + delattr(self.w3_list[i], "weight") |
| 856 | + delattr(self.w2_list[i], "weight") |
| 857 | + htcore.mark_step() |
| 858 | + delattr(self, "w1_list") |
| 859 | + delattr(self, "w3_list") |
| 860 | + delattr(self, "w2_list") |
| 861 | + |
| 862 | + def forward_quant(self, |
| 863 | + hidden_states, |
| 864 | + expert_routing_table, |
| 865 | + router_weights, |
| 866 | + permuted_weights=True, |
| 867 | + activation="silu"): |
| 868 | + qinput = self.quant_input(hidden_states) |
| 869 | + output = self.dynamic_moe_op( |
| 870 | + hidden_states=qinput, |
| 871 | + expert_routing_table=expert_routing_table, |
| 872 | + router_weights=router_weights, |
| 873 | + w12=self.w13_weight, |
| 874 | + w3=self.w2_weight, |
| 875 | + d_scale_w12=self.scale_w1, |
| 876 | + d_scale_w3=self.scale_w2, |
| 877 | + d_scale_hidden_states=self.scale_input, |
| 878 | + d_scale_intermediate_hidden_states=self.scale_intermediate, |
| 879 | + permuted_weights=permuted_weights, |
| 880 | + activation=activation, |
| 881 | + experts_min=self.experts_min, |
| 882 | + experts_max=self.experts_max, |
| 883 | + ) |
| 884 | + return output |
| 885 | + |
| 886 | + |
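PatchedVllmMixtureOfExpertsOpV2.post_process above fuses each expert's w1 and w3 projections into a single w13 tensor and keeps one weight scale per fused tensor. A rough sketch of that fusion for a single expert, with made-up shapes and scale values (not the real module state):

import torch

w1 = torch.randn(8, 16, dtype=torch.bfloat16)   # gate projection of one expert
w3 = torch.randn(8, 16, dtype=torch.bfloat16)   # up projection of the same expert
scale_w1, scale_w3 = torch.tensor(0.5), torch.tensor(0.25)

w13 = torch.cat((w1, w3), dim=0).contiguous()   # fused (16, 16) tensor, passed to the op as w12
scale_w13 = max(scale_w1, scale_w3)             # single shared scale, taken as the max of the two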
811 | 887 | class PatchedKVCache(PatchedModuleBase):
|
812 | 888 | # Module to patch KVCache module from llama model
|
813 | 889 | def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
|
|