@@ -3601,3 +3601,70 @@ def __exit__(self, exc_type, exc_value, traceback):
         for block in self._model.blocks:
             block.forward = block._orig_forward
             block.attn.forward = block.attn._orig_forward
+
+
+def _granite_moe_topk_gating_forward(self, hidden_states):
+    # compute the top_k routing decision
+    logits = self.layer(hidden_states).float()  # [batch_size x seq_len, num_experts]
+    top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1)  # [num_tokens, top_k]
+    top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(hidden_states)  # [num_tokens, top_k]
+
+    # compute the number of inputs routed to each expert
+    zeros = torch.zeros(
+        [top_k_gates.size(0), self.num_experts], dtype=top_k_gates.dtype, device=top_k_gates.device
+    )  # [num_tokens, num_experts]
+    gates = zeros.scatter(1, top_k_indices, 1)  # [num_tokens, num_experts]
+    expert_size = gates.long().sum(0)  # [num_experts,]
+    # difference from the original: expert_size = expert_size.tolist() was removed because it traces incorrectly
+
+    # sort and group input tokens according to expert assignment
+    top_k_experts = top_k_indices.flatten()  # [num_tokens * top_k]
+    _, index_sorted_experts = top_k_experts.sort(0)  # [num_tokens * top_k]
+    batch_index = index_sorted_experts.div(self.top_k, rounding_mode="trunc")  # [num_tokens * top_k]
+
+    # gather the gate values for grouped input tokens
+    top_k_gates = top_k_gates.flatten()  # [num_tokens * top_k]
+    batch_gates = top_k_gates[index_sorted_experts]  # [num_tokens * top_k]
+
+    return index_sorted_experts, batch_index, batch_gates, expert_size, logits
+
+
+def _granite_moe_parallel_experts_forward(self, inputs, expert_size):
+    output_list = []
+    # differences from the original:
+    # 1) after the gating patch, expert_size is a tensor rather than a list of ints, so the original inputs.split(expert_size) cannot be used
+    # 2) expert input slices are taken one by one via index_start:next_index instead of being precomputed once before the loop
+    index_start = torch.tensor(0, dtype=torch.int64)
+    for i in range(self.num_experts):
+        next_index = index_start + expert_size[i]
+        output_list.append(F.linear(inputs[index_start:next_index], self.weight[i]))
+        index_start = next_index
+    results = torch.cat(output_list, dim=0)
+    return results
+
+
+class GraniteMoEModelPatcher(LlamaModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        for layer in self._model.model.layers:
+            block_sparse_moe = layer.block_sparse_moe
+            block_sparse_moe.router._orig_forward = block_sparse_moe.router.forward
+            block_sparse_moe.router.forward = types.MethodType(
+                _granite_moe_topk_gating_forward, block_sparse_moe.router
+            )
+            block_sparse_moe.input_linear._orig_forward = block_sparse_moe.input_linear.forward
+            block_sparse_moe.input_linear.forward = types.MethodType(
+                _granite_moe_parallel_experts_forward, block_sparse_moe.input_linear
+            )
+            block_sparse_moe.output_linear._orig_forward = block_sparse_moe.output_linear.forward
+            block_sparse_moe.output_linear.forward = types.MethodType(
+                _granite_moe_parallel_experts_forward, block_sparse_moe.output_linear
+            )
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for layer in self._model.model.layers:
+            block_sparse_moe = layer.block_sparse_moe
+            block_sparse_moe.router.forward = block_sparse_moe.router._orig_forward
+            block_sparse_moe.input_linear.forward = block_sparse_moe.input_linear._orig_forward
+            block_sparse_moe.output_linear.forward = block_sparse_moe.output_linear._orig_forward
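
For sanity-checking the two patched forwards together, here is a minimal eager-mode sketch. The toy sizes, the bias-free `Linear` standing in for the router's `self.layer`, the random `expert_weight`, and the `allclose` check against the original list-based split are all illustrative assumptions, not part of the patch:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, hidden, num_experts, top_k = 6, 8, 4, 2

router_layer = torch.nn.Linear(hidden, num_experts, bias=False)  # stands in for self.layer
expert_weight = torch.randn(num_experts, hidden, hidden)         # stands in for self.weight
hidden_states = torch.randn(num_tokens, hidden)

# gating, as in _granite_moe_topk_gating_forward
logits = router_layer(hidden_states).float()                     # [num_tokens, num_experts]
top_k_logits, top_k_indices = logits.topk(top_k, dim=1)          # [num_tokens, top_k]
gates = torch.zeros(num_tokens, num_experts).scatter(1, top_k_indices, 1)
expert_size = gates.long().sum(0)                                # stays a tensor; deliberately no .tolist()
_, index_sorted_experts = top_k_indices.flatten().sort(0)
batch_index = index_sorted_experts.div(top_k, rounding_mode="trunc")

# expert computation, as in _granite_moe_parallel_experts_forward
expert_inputs = hidden_states[batch_index]                       # tokens grouped by expert
output_list = []
index_start = torch.tensor(0, dtype=torch.int64)                 # tensor slice bounds keep tracing happy
for i in range(num_experts):
    next_index = index_start + expert_size[i]
    output_list.append(F.linear(expert_inputs[index_start:next_index], expert_weight[i]))
    index_start = next_index
results = torch.cat(output_list, dim=0)

# eager-mode check against the original list-based split
chunks = expert_inputs.split(expert_size.tolist(), dim=0)
reference = torch.cat([F.linear(c, expert_weight[i]) for i, c in enumerate(chunks)], dim=0)
assert torch.allclose(results, reference)
```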
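The patcher class itself relies only on the standard save/patch/restore idiom around `types.MethodType`. A toy illustration of that mechanic, with all names hypothetical:

```python
import types

class Router:
    def forward(self, x):
        return "original"

def _patched_forward(self, x):
    return "patched"

router = Router()
router._orig_forward = router.forward                        # keep the original bound method
router.forward = types.MethodType(_patched_forward, router)  # bind the replacement to this instance
assert router.forward(0) == "patched"
router.forward = router._orig_forward                        # restore, as done in __exit__
assert router.forward(0) == "original"
```

Because the replacement is bound per instance, each router, `input_linear`, and `output_linear` module gets its own patched method, and exiting the context restores every original forward.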