import os

VLLM_LOAD_FOR_INC = os.environ.get("VLLM_LOAD_FOR_INC", "0") == "1"
+LOW_CPU_MEM = os.environ.get("LOW_CPU_MEM", "0") == "1"

# ==-------------------------------------------------------------------------==
# VLLM-HPU-EXT PATCH Start
import habana_frameworks.torch.core as htcore
import habana_frameworks.torch as htorch

-
class MoeMatmul(torch.nn.Module):
    def __init__(self):
        super().__init__()
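Note: both flags are read once at import time, so they must be in the environment before this module is imported. A minimal usage sketch (the patched module's import path is illustrative, not taken from this diff):

    import os

    # Opt into the split w1/w3 weight path before the patched module is
    # imported, since LOW_CPU_MEM is evaluated at import time.
    os.environ["LOW_CPU_MEM"] = "1"

    # import the patched module afterwards, e.g.:
    # from vllm_hpu_extension import patched_moe  # illustrative name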
@@ -57,9 +57,17 @@ def forward(self, state, expert_id, w):
class VllmMixtureOfExpertsOp(torch.nn.Module):
    def __init__(self, num_total_experts):
        super().__init__()
-        self.w13_list = torch.nn.ModuleList(
-            [MoeMatmul() for _ in range(num_total_experts)]
-        )
+        if not LOW_CPU_MEM:
+            self.w13_list = torch.nn.ModuleList(
+                [MoeMatmul() for _ in range(num_total_experts)]
+            )
+        else:
+            self.w1_list = torch.nn.ModuleList(
+                [MoeMatmul() for _ in range(num_total_experts)]
+            )
+            self.w3_list = torch.nn.ModuleList(
+                [MoeMatmul() for _ in range(num_total_experts)]
+            )
        self.w2_list = torch.nn.ModuleList(
            [MoeMatmul() for _ in range(num_total_experts)]
        )
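Note: the split layout holds exactly the same parameters per expert as the fused one; a fused w13 of shape (2*I, H) is just w1 (I, H) stacked on w3 (I, H). A quick eager-mode check with illustrative sizes:

    import torch

    I, H = 4, 8                       # illustrative intermediate / hidden sizes

    w13 = torch.randn(2 * I, H)       # fused layout: one (2*I, H) matrix
    w1, w3 = w13[:I], w13[I:]         # split layout: two (I, H) matrices

    # same parameters either way; the slices are views, not copies
    assert torch.equal(torch.cat([w1, w3], dim=0), w13)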
@@ -82,20 +90,36 @@ def forward(
        assert self.experts_min is not None, "`experts_min` is not set"
        assert self.experts_max is not None, "`experts_max` is not set"
        experts_min, experts_max = self.experts_min, self.experts_max
-        w1_list = [self.w13_list[i].weight.squeeze() for i in experts_range]
-        w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range]
-        return torch.ops.hpu.mixture_of_experts(
-            hidden_states=hidden_states,
-            expert_routing_table=expert_routing_table,
-            router_weights=router_weights,
-            w12=w1_list,
-            w3=w2_list,
-            permuted_weights=permuted_weights,
-            activation=activation,
-            experts_min=experts_min,
-            experts_max=experts_max,
-        )
-
+        if not LOW_CPU_MEM:
+            w1_list = [self.w13_list[i].weight.squeeze() for i in experts_range]
+            w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range]
+            return torch.ops.hpu.mixture_of_experts.fused_weights(
+                hidden_states=hidden_states,
+                expert_routing_table=expert_routing_table,
+                router_weights=router_weights,
+                w12=w1_list,
+                w3=w2_list,
+                permuted_weights=permuted_weights,
+                activation=activation,
+                experts_min=experts_min,
+                experts_max=experts_max,
+            )
+        else:
+            w1_list = [self.w1_list[i].weight.squeeze() for i in experts_range]
+            w3_list = [self.w3_list[i].weight.squeeze() for i in experts_range]
+            w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range]
+            return torch.ops.hpu.mixture_of_experts.default(
+                hidden_states=hidden_states,
+                expert_routing_table=expert_routing_table,
+                router_weights=router_weights,
+                w1=w1_list,
+                w2=w2_list,
+                w3=w3_list,
+                permuted_weights=permuted_weights,
+                activation=activation,
+                experts_min=experts_min,
+                experts_max=experts_max,
+            )
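Note: torch.ops.hpu.mixture_of_experts is a fused HPU kernel, so its exact semantics are not visible in this diff. As a rough eager-mode reference for the split-weights variant, assuming SiLU-gated experts and top-k routing (an assumption, not the kernel's contract):

    import torch
    import torch.nn.functional as F

    def moe_reference(x, expert_routing_table, router_weights,
                      w1_list, w2_list, w3_list):
        # Per-token loop; illustrative only, not the fused HPU kernel.
        out = torch.zeros_like(x)
        for t in range(x.shape[0]):                          # each token
            for k in range(expert_routing_table.shape[1]):   # each routed expert
                e = expert_routing_table[t, k].item()
                gate = F.silu(x[t] @ w1_list[e].T)           # gate_proj + act
                up = x[t] @ w3_list[e].T                     # up_proj
                out[t] += router_weights[t, k] * ((gate * up) @ w2_list[e].T)
        return out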
class _DynamicFusedMOE(torch.nn.Module):
    def __init__(self, num_total_experts):
@@ -165,14 +189,33 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        # Fused gate_up_proj (column parallel)
-        w13_weight = torch.nn.Parameter(torch.empty(
-            num_experts,
-            2 * intermediate_size_per_partition,
-            hidden_size,
-            dtype=params_dtype),
-            requires_grad=False)
-        layer.register_parameter("w13_weight", w13_weight)
-        set_weight_attrs(w13_weight, extra_weight_attrs)
+        if not LOW_CPU_MEM:
+            w13_weight = torch.nn.Parameter(torch.empty(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=params_dtype),
+                requires_grad=False)
+            layer.register_parameter("w13_weight", w13_weight)
+            set_weight_attrs(w13_weight, extra_weight_attrs)
+        else:
+            w1_weight = torch.nn.Parameter(torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                hidden_size,
+                dtype=params_dtype),
+                requires_grad=False)
+            layer.register_parameter("w1_weight", w1_weight)
+            set_weight_attrs(w1_weight, extra_weight_attrs)
+
+            w3_weight = torch.nn.Parameter(torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                hidden_size,
+                dtype=params_dtype),
+                requires_grad=False)
+            layer.register_parameter("w3_weight", w3_weight)
+            set_weight_attrs(w3_weight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(torch.empty(
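Note: a checkpoint that stores a fused gate_up_proj can still populate the split parameters; assuming the vLLM convention that gate_proj occupies the first half of the fused tensor along dim 1, the split is a cheap slice:

    import torch

    num_experts, I, H = 8, 4, 16                   # illustrative sizes
    ckpt_w13 = torch.randn(num_experts, 2 * I, H)  # fused checkpoint tensor

    w1 = ckpt_w13[:, :I, :]  # gate_proj half (assumed first, per vLLM layout)
    w3 = ckpt_w13[:, I:, :]  # up_proj half
    assert w1.shape == w3.shape == (num_experts, I, H)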
@@ -327,8 +370,8 @@ def forward_hpu_original(
            max_expert = (i + 1) * n_expert_slice
            # w13_list_slice = [w13_list[i].weight.squeeze() for i in range(min_expert, max_expert)]
            # w2_list_slice = [w2_list[i].weight.squeeze() for i in range(min_expert, max_expert)]
-            w13_list_slice = [layer.w13_weight[j].squeeze().clone() for j in range(min_expert, max_expert)]
-            w2_list_slice = [layer.w2_weight[j].squeeze().clone() for j in range(min_expert, max_expert)]
+            w13_list_slice = [layer.w13_weight[j].squeeze() for j in range(min_expert, max_expert)]
+            w2_list_slice = [layer.w2_weight[j].squeeze() for j in range(min_expert, max_expert)]

            final_hidden_states += torch.ops.hpu.mixture_of_experts(hidden_states=x,
                expert_routing_table=topk_ids.to(torch.int64),
                router_weights=topk_weights.to(x.dtype),
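Note: dropping .clone() avoids materializing a second copy of every expert slice; indexing plus squeeze() returns a view over the existing storage. A quick demonstration of the difference:

    import torch

    w = torch.zeros(8, 1, 16, 32)            # (experts, 1, I, H), illustrative
    view = w[0].squeeze()                     # shares storage with w
    copy = w[0].squeeze().clone()             # allocates a fresh buffer

    assert view.data_ptr() == w.data_ptr()    # no new memory
    assert copy.data_ptr() != w.data_ptr()    # duplicated per expert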
@@ -612,21 +655,54 @@ def __init__(
            max_expert = (i + 1) * n_expert_slice
            # Note: clone weight will cause OoM.
            # rank_debug(f"i:{i}, num_experts:{num_experts} loading experts from {min_expert} to {max_expert}, layer.w13_weight.shape : {layer.w13_weight.shape}")
-            w13_list_slice = [
-                layer.w13_weight[j]
-                for j in range(min_expert, max_expert)
-            ]
+            if not LOW_CPU_MEM:
+                w13_list_slice = [
+                    layer.w13_weight[j]
+                    for j in range(min_expert, max_expert)
+                ]
+            else:
+                w1_list_slice = [
+                    layer.w1_weight[j]
+                    for j in range(min_expert, max_expert)
+                ]
+                w3_list_slice = [
+                    layer.w3_weight[j]
+                    for j in range(min_expert, max_expert)
+                ]
            w2_list_slice = [
                layer.w2_weight[j]
                for j in range(min_expert, max_expert)
            ]
-            for index in range(len(w13_list_slice)):
-                _temp_expert_group.MoeOp.w13_list[index].set_weight(
-                    w13_list_slice[index]
-                )
-                _temp_expert_group.MoeOp.w2_list[index].set_weight(
-                    w2_list_slice[index]
-                )
+            for index in range(len(w2_list_slice)):
+                if not LOW_CPU_MEM:
+                    _temp_expert_group.MoeOp.w13_list[index].set_weight(
+                        w13_list_slice[index]
+                    )
+                    _temp_expert_group.MoeOp.w2_list[index].set_weight(
+                        w2_list_slice[index]
+                    )
+                else:
+                    _temp_expert_group.MoeOp.w1_list[index].set_weight(
+                        torch.nn.Parameter(torch.empty(
+                            w1_list_slice[index].shape,
+                            dtype=w1_list_slice[index].dtype,
+                        ),
+                            requires_grad=False)
+                    )
+                    _temp_expert_group.MoeOp.w3_list[index].set_weight(
+                        torch.nn.Parameter(torch.empty(
+                            w3_list_slice[index].shape,
+                            dtype=w3_list_slice[index].dtype,
+                        ),
+                            requires_grad=False)
+                    )
+                    _temp_expert_group.MoeOp.w2_list[index].set_weight(
+                        torch.nn.Parameter(torch.empty(
+                            w2_list_slice[index].shape,
+                            dtype=w2_list_slice[index].dtype,
+                        ),
+                            requires_grad=False)
+                    )

        # FIXME: (Yi) pass `experts_min` and `experts_max` to MoeOp.
        setattr(_temp_expert_group.MoeOp, "experts_min", min_expert + ep_shift)
        setattr(_temp_expert_group.MoeOp, "experts_max", max_expert - 1 + ep_shift)
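Note: in the LOW_CPU_MEM branch above, set_weight receives an uninitialized torch.empty placeholder with the right shape and dtype; the real values arrive later when weight_loader assigns .data. A minimal sketch of that two-step pattern, assuming set_weight simply stores the tensor on the module:

    import torch

    class MoeMatmul(torch.nn.Module):
        def set_weight(self, w):          # assumed behavior: just store it
            self.weight = w

    m = MoeMatmul()

    # step 1: register an empty placeholder (no checkpoint data yet)
    m.set_weight(torch.nn.Parameter(torch.empty(4, 8), requires_grad=False))

    # step 2: the loader later swaps in the real tensor without reallocating
    loaded = torch.randn(4, 8)
    m.weight.data = loaded
    assert torch.equal(m.weight.data, loaded)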
@@ -891,13 +967,31 @@ def weight_loader(self, param: torch.nn.Parameter,

        # Case model weights
        if "weight" in weight_name:
-            self._load_model_weight_or_group_weight_scale(
-                shard_id=shard_id,
-                shard_dim=shard_dim,
-                loaded_weight=loaded_weight,
-                expert_data=expert_data,
-                tp_rank=tp_rank,
-                expert_id=expert_id)
+            if not LOW_CPU_MEM:
+                self._load_model_weight_or_group_weight_scale(
+                    shard_id=shard_id,
+                    shard_dim=shard_dim,
+                    loaded_weight=loaded_weight,
+                    expert_data=expert_data,
+                    tp_rank=tp_rank,
+                    expert_id=expert_id)
+            else:
+                expert_group = expert_id // (self.num_experts // self.num_expert_group)
+                lst_id = expert_id % (self.num_experts // self.num_expert_group)
+                expert = getattr(self, f"_temp_expert_group_{expert_group}")
+                if shard_id in ("w1", "w3"):
+                    if shard_id == "w1":
+                        expert_data = expert.MoeOp.w1_list[lst_id].weight
+                        expert_data.data = loaded_weight
+                    else:
+                        assert shard_id == "w3"
+                        expert_data = expert.MoeOp.w3_list[lst_id].weight
+                        expert_data.data = loaded_weight
+
+                elif shard_id == "w2":
+                    # load w2
+                    expert_data = expert.MoeOp.w2_list[lst_id].weight
+                    expert_data.data = loaded_weight
            return

    @staticmethod
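Note: the group arithmetic maps a global expert_id to the expert-group module that owns it and the local slot inside that group. A worked example with illustrative sizes:

    num_experts, num_expert_group = 16, 4
    per_group = num_experts // num_expert_group   # 4 experts per group

    for expert_id in (0, 5, 9, 15):
        expert_group = expert_id // per_group
        lst_id = expert_id % per_group
        print(expert_id, "->", (expert_group, lst_id))
    # 0 -> (0, 0), 5 -> (1, 1), 9 -> (2, 1), 15 -> (3, 3)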
@@ -973,14 +1067,19 @@ def make_expert_params_mapping(
            ckpt_up_proj_name: str,
            num_experts: int) -> List[Tuple[str, str, int, str]]:

+        mapping = {
+            ckpt_gate_proj_name: "experts.w1_" if LOW_CPU_MEM else "experts.w13_",
+            ckpt_down_proj_name: "experts.w2_",
+            ckpt_up_proj_name: "experts.w3_" if LOW_CPU_MEM else "experts.w13_",
+        }
        return [
            # (param_name, weight_name, expert_id, shard_id)
-            ("experts.w13_" if weight_name
-             in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
+            (mapping[weight_name],
             f"experts.{expert_id}.{weight_name}.", expert_id, shard_id)
            for expert_id in range(num_experts) for shard_id, weight_name in [
                ("w1", ckpt_gate_proj_name),
                ("w2", ckpt_down_proj_name),
                ("w3", ckpt_up_proj_name),
            ]
        ]
+
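Note: with LOW_CPU_MEM set, gate_proj and up_proj now map to their own parameter prefixes instead of the shared experts.w13_. A standalone sketch of the mapping logic, using the common gate/down/up checkpoint names (assumed here, not taken from this diff):

    LOW_CPU_MEM = True

    def expert_params_mapping(gate, down, up, num_experts):
        mapping = {
            gate: "experts.w1_" if LOW_CPU_MEM else "experts.w13_",
            down: "experts.w2_",
            up: "experts.w3_" if LOW_CPU_MEM else "experts.w13_",
        }
        return [
            (mapping[name], f"experts.{eid}.{name}.", eid, shard)
            for eid in range(num_experts)
            for shard, name in [("w1", gate), ("w2", down), ("w3", up)]
        ]

    for row in expert_params_mapping("gate_proj", "down_proj", "up_proj", 1):
        print(row)
    # ('experts.w1_', 'experts.0.gate_proj.', 0, 'w1')
    # ('experts.w2_', 'experts.0.down_proj.', 0, 'w2')
    # ('experts.w3_', 'experts.0.up_proj.', 0, 'w3')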