@@ -1356,3 +1356,62 @@ def __exit__(self, exc_type, exc_value, traceback):
        for layer in self._model.transformer.h:
            if hasattr(layer.attn, "_orig_attn"):
                layer.attn._attn = layer.attn._orig_attn
+
+
+def _dbrx_experts_forward(
+    self, x: torch.Tensor, weights: torch.Tensor, top_weights: torch.Tensor, top_experts: torch.LongTensor
+):
+    bsz, q_len, hidden_size = x.shape
+    x = x.view(-1, hidden_size)
+    out = torch.zeros_like(x)
+
+    expert_mask = torch.nn.functional.one_hot(top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
+    # Chunk experts at once to avoid storing full parameter multiple times in autograd
+    w1_chunked = self.mlp.w1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
+        self.moe_num_experts, dim=0
+    )
+    v1_chunked = self.mlp.v1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
+        self.moe_num_experts, dim=0
+    )
+    w2_chunked = self.mlp.w2.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
+        self.moe_num_experts, dim=0
+    )
+    w1_chunked = [w1.squeeze(dim=0) for w1 in w1_chunked]
+    v1_chunked = [v1.squeeze(dim=0) for v1 in v1_chunked]
+    w2_chunked = [w2.squeeze(dim=0) for w2 in w2_chunked]
+    for expert_idx in range(0, self.moe_num_experts):
+        topk_idx, token_idx = torch.where(expert_mask[expert_idx])
+
+        token_list = token_idx
+        topk_list = topk_idx
+
+        expert_tokens = x[None, token_list].reshape(-1, hidden_size)
+        expert_out = (
+            self.mlp(expert_tokens, w1_chunked[expert_idx], v1_chunked[expert_idx], w2_chunked[expert_idx])
+            * top_weights[token_list, topk_list, None]
+        )
+
+        out.index_add_(0, token_idx, expert_out)
+
+    out = out.reshape(bsz, q_len, hidden_size)
+    return out
+
+
+class DBRXModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+
+        for block in self._model.transformer.blocks:
+            rotary_emb = block.norm_attn_norm.attn.rotary_emb
+            if rotary_emb.inv_freq is None:
+                inv_freq = 1.0 / (
+                    rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
+                )
+                rotary_emb.inv_freq = inv_freq
+            block.ffn.experts._orig_forward = block.ffn.experts.forward
+            block.ffn.experts.forward = types.MethodType(_dbrx_experts_forward, block.ffn.experts)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for block in self._model.transformer.blocks:
+            block.ffn.experts.forward = block.ffn.experts._orig_forward
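For reference, here is a minimal, self-contained sketch of the monkey-patching pattern the patcher above relies on: each target module's `forward` is swapped for a replacement via `types.MethodType` inside `__enter__`, and the original is restored in `__exit__`. The `TinyPatcher`, `Doubler`, and `_patched_forward` names are hypothetical and are not part of the project's actual API; they only illustrate the technique.

```python
import types

import torch


def _patched_forward(self, x):
    # Stand-in for _dbrx_experts_forward: a replacement body used while patched.
    return x + 1


class TinyPatcher:
    """Swap each child's forward on entry and restore the original on exit."""

    def __init__(self, model):
        self._model = model

    def __enter__(self):
        for module in self._model.children():
            # Keep a reference to the original bound method, then override it
            # with an instance-level bound method pointing at the replacement.
            module._orig_forward = module.forward
            module.forward = types.MethodType(_patched_forward, module)
        return self._model

    def __exit__(self, exc_type, exc_value, traceback):
        for module in self._model.children():
            module.forward = module._orig_forward


class Doubler(torch.nn.Module):
    def forward(self, x):
        return x * 2


model = torch.nn.Sequential(Doubler())
x = torch.full((2,), 3.0)

with TinyPatcher(model):
    print(model(x))  # tensor([4., 4.]) -- patched forward (x + 1) is active
print(model(x))      # tensor([6., 6.]) -- original forward (x * 2) restored
```

Because the swap happens on module instances rather than on the class, any tracing or export performed inside the `with` block records the patched behaviour, and the model is left untouched afterwards.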