@@ -3581,3 +3581,72 @@ def __exit__(self, exc_type, exc_value, traceback):
         for block in self._model.blocks:
             block.forward = block._orig_forward
             block.attn.forward = block.attn._orig_forward
+
+
+# copied from https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/granitemoe/modeling_granitemoe.py#L321
+def _granite_moe_topk_gating_forward(self, hidden_states):
+    # compute the top_k routing decision
+    logits = self.layer(hidden_states).float()  # [batch_size x seq_len, num_experts]
+    top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1)  # [num_tokens, top_k]
+    top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(hidden_states)  # [num_tokens, top_k]
+
+    # compute number of input given to each expert
+    zeros = torch.zeros(
+        [top_k_gates.size(0), self.num_experts], dtype=top_k_gates.dtype, device=top_k_gates.device
+    )  # [num_tokens, num_experts]
+    gates = zeros.scatter(1, top_k_indices, 1)  # [num_tokens, num_experts]
+    expert_size = gates.long().sum(0)  # [num_experts,]
+    # difference with original: removed expert_size = expert_size.tolist() due to incorrect tracing
+
+    # sort and group input tokens according to expert assignment
+    top_k_experts = top_k_indices.flatten()  # [num_tokens * top_k]
+    _, index_sorted_experts = top_k_experts.sort(0)  # [num_tokens * top_k]
+    batch_index = index_sorted_experts.div(self.top_k, rounding_mode="trunc")  # [num_tokens * top_k]
+
+    # gather the gate values for grouped input tokens
+    top_k_gates = top_k_gates.flatten()  # [num_tokens * top_k]
+    batch_gates = top_k_gates[index_sorted_experts]  # [num_tokens * top_k]
+
+    return index_sorted_experts, batch_index, batch_gates, expert_size, logits
+
+
+# copied from https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/granitemoe/modeling_granitemoe.py#L281
+def _granite_moe_parallel_experts_forward(self, inputs, expert_size):
+    output_list = []
+    # differences with the original:
+    # 1) after the gating patch, expert_size is a tensor rather than a list of ints, so the original inputs.split(expert_size) cannot be used
+    # 2) the per-expert input slices are taken one by one via index_start:next_index instead of being precomputed before the loop
+    index_start = torch.tensor(0, dtype=torch.int64)
+    for i in range(self.num_experts):
+        next_index = index_start + expert_size[i]
+        output_list.append(F.linear(inputs[index_start:next_index], self.weight[i]))
+        index_start = next_index
+    results = torch.cat(output_list, dim=0)
+    return results
+
+
+class GraniteMoEModelPatcher(LlamaModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        for layer in self._model.model.layers:
+            block_sparse_moe = layer.block_sparse_moe
+            block_sparse_moe.router._orig_forward = block_sparse_moe.router.forward
+            block_sparse_moe.router.forward = types.MethodType(
+                _granite_moe_topk_gating_forward, block_sparse_moe.router
+            )
+            block_sparse_moe.input_linear._orig_forward = block_sparse_moe.input_linear.forward
+            block_sparse_moe.input_linear.forward = types.MethodType(
+                _granite_moe_parallel_experts_forward, block_sparse_moe.input_linear
+            )
+            block_sparse_moe.output_linear._orig_forward = block_sparse_moe.output_linear.forward
+            block_sparse_moe.output_linear.forward = types.MethodType(
+                _granite_moe_parallel_experts_forward, block_sparse_moe.output_linear
+            )
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for layer in self._model.model.layers:
+            block_sparse_moe = layer.block_sparse_moe
+            block_sparse_moe.router.forward = block_sparse_moe.router._orig_forward
+            block_sparse_moe.input_linear.forward = block_sparse_moe.input_linear._orig_forward
+            block_sparse_moe.output_linear.forward = block_sparse_moe.output_linear._orig_forward
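A note on the removed `tolist()` call: converting `expert_size` to Python ints freezes the per-expert counts as constants when the model is traced, whereas keeping it as a tensor preserves the dependence on the actual input. Below is a minimal, self-contained sketch of that effect; the toy `routing_counts` function and input shapes are illustrative assumptions, not part of the patch.

```python
import torch


def routing_counts(gates):
    # keeping expert_size as a tensor keeps it data-dependent under tracing
    expert_size = gates.long().sum(0)
    # expert_size.tolist() would instead bake the counts from the example
    # input into the traced graph as Python int constants
    return expert_size


traced = torch.jit.trace(routing_counts, torch.randint(0, 2, (4, 8)))
# the traced function recomputes the counts for a new input with a different number of tokens
print(traced(torch.randint(0, 2, (6, 8))))
```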