@@ -3575,6 +3575,301 @@ def __exit__(self, exc_type, exc_value, traceback):
            block.self_attn.forward = block.self_attn._orig_forward


+class DeepseekPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        self_attn = {
+            "deepseek_v3": deepseek_v3_attn_forward,
+            "deepseek_v2": deepseek_v2_attn_forward,
+            "deepseek": minicpm3_attn_forward,
+        }
+
+        self_attn_fwd = self_attn.get(self._model.config.model_type)
+        for block in self._model.model.layers:
+            if self_attn_fwd is not None:
+                block.self_attn._orig_forward = block.self_attn.forward
+                block.self_attn.forward = types.MethodType(self_attn_fwd, block.self_attn)
+            if hasattr(block.mlp, "moe_infer"):
+                # stash the original moe_infer so __exit__ can restore it
+                block.mlp._orig_moe_infer = block.mlp.moe_infer
+                block.mlp.moe_infer = types.MethodType(deepseek_moe_infer, block.mlp)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for block in self._model.model.layers:
+            block.self_attn.forward = block.self_attn._orig_forward
+            if hasattr(block.mlp, "_orig_moe_infer"):
+                block.mlp.moe_infer = block.mlp._orig_moe_infer
+
+
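For reference, a minimal standalone sketch of the monkey-patch-and-restore pattern this patcher relies on. The classes below (TinyAttn, TinyPatcher) are hypothetical toys, not part of the patch or of optimum-intel:

import types

import torch


class TinyAttn:
    """Hypothetical stand-in for a decoder block's self_attn module."""

    def forward(self, x):
        return x


def patched_forward(self, x):
    # replacement bound to the instance with types.MethodType, as in DeepseekPatcher
    return x * 2


class TinyPatcher:
    """Same enter/exit bookkeeping as above, applied to a toy list of blocks."""

    def __init__(self, blocks):
        self._blocks = blocks

    def __enter__(self):
        for block in self._blocks:
            block._orig_forward = block.forward
            block.forward = types.MethodType(patched_forward, block)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        for block in self._blocks:
            block.forward = block._orig_forward


blocks = [TinyAttn(), TinyAttn()]
x = torch.ones(2)
with TinyPatcher(blocks):
    assert torch.equal(blocks[0].forward(x), x * 2)  # patched inside the context
assert torch.equal(blocks[0].forward(x), x)  # original forward restored on exit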
+def deepseek_v3_attn_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value=None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    # modified from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L751
+    def rotate_half(x):
+        """Rotates half the hidden dims of the input."""
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+        orig_dtype = k.dtype
+        cos = cos[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, dim]
+        sin = sin[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, dim]
+        q_fp32 = q.to(dtype=torch.float32, device=q.device)
+        k_fp32 = k.to(dtype=torch.float32, device=k.device)
+        q_embed = (q_fp32 * cos) + (rotate_half(q_fp32) * sin)
+        k_embed = (k_fp32 * cos) + (rotate_half(k_fp32) * sin)
+        return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype)
+
+    if output_attentions:
+        return self._orig_forward(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+
+    bsz, q_len, _ = hidden_states.size()
+
+    if self.q_lora_rank is None:
+        q = self.q_proj(hidden_states)
+    else:
+        q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
+    q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
+    q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+
+    compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+    compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+    k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
+    kv = (
+        self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
+        .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+        .transpose(1, 2)
+    )
+
+    k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+    kv_seq_len = value_states.shape[-2]
+    if past_key_value is not None:
+        if self.layer_idx is None:
+            raise ValueError(
+                f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                "with a layer index."
+            )
+        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
+    q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
+
+    # Difference from the original code: k_pe.new_empty creates a constant tensor in torchscript,
+    # so query/key states are built with torch.concat instead of slice assignment.
+    query_states = torch.concat([q_nope, q_pe], dim=-1)
+    # query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+    # query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
+    # query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
+    key_states = torch.concat([k_nope, k_pe.expand(-1, self.num_heads, -1, -1)], dim=-1)
+    # key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+    # key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
+    # key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
+    if past_key_value is not None:
+        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+    if attention_mask is not None:
+        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+            )
+
+    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+    # Reference: https://github.com/pytorch/pytorch/issues/112577.
+    if query_states.device.type == "cuda" and attention_mask is not None:
+        query_states = query_states.contiguous()
+        key_states = key_states.contiguous()
+        value_states = value_states.contiguous()
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query_states,
+        key_states,
+        value_states,
+        attn_mask=attention_mask,
+        dropout_p=self.attention_dropout if self.training else 0.0,
+        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+        is_causal=self.is_causal and attention_mask is None and q_len > 1,
+    )
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
+
+    attn_output = self.o_proj(attn_output)
+
+    return attn_output, None, past_key_value
+
+
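The commented-out lines above record why the patch swaps slice assignment for torch.concat. A small eager-mode sketch (shapes are made up for illustration) showing that the two constructions produce the same tensor:

import torch

bsz, num_heads, q_len = 1, 2, 3
qk_nope_head_dim, qk_rope_head_dim = 4, 2
q_head_dim = qk_nope_head_dim + qk_rope_head_dim

q_nope = torch.randn(bsz, num_heads, q_len, qk_nope_head_dim)
q_pe = torch.randn(bsz, num_heads, q_len, qk_rope_head_dim)

# original style: allocate with new_empty, then slice-assign
# (new_empty becomes a constant tensor when traced with torchscript)
ref = q_nope.new_empty(bsz, num_heads, q_len, q_head_dim)
ref[:, :, :, :qk_nope_head_dim] = q_nope
ref[:, :, :, qk_nope_head_dim:] = q_pe

# patched style: a single concat along the head dimension
patched = torch.concat([q_nope, q_pe], dim=-1)

assert torch.equal(ref, patched)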
+def deepseek_v2_attn_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value=None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    # modified from https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/modeling_deepseek.py#L806
+    def rotate_half(x):
+        """Rotates half the hidden dims of the input."""
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+        cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+        sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+
+        b, h, s, d = q.shape
+        q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
+
+        b, h, s, d = k.shape
+        k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
+
+        q_embed = (q * cos) + (rotate_half(q) * sin)
+        k_embed = (k * cos) + (rotate_half(k) * sin)
+        return q_embed, k_embed
+
+    if output_attentions:
+        return self._orig_forward(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+
+    bsz, q_len, _ = hidden_states.shape
+
+    if self.q_lora_rank is None:
+        q = self.q_proj(hidden_states)
+    else:
+        q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
+    q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
+    q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+
+    compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+    compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+    k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
+    kv = (
+        self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
+        .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+        .transpose(1, 2)
+    )
+
+    k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+    kv_seq_len = value_states.shape[-2]
+    if past_key_value is not None:
+        if self.layer_idx is None:
+            raise ValueError(
+                f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                "with a layer index."
+            )
+        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
+    q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
+
+    # Difference from the original code: k_pe.new_empty creates a constant tensor in torchscript,
+    # so query/key states are built with torch.concat instead of slice assignment.
+    query_states = torch.concat([q_nope, q_pe], dim=-1)
+    # query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+    # query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
+    # query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
+    key_states = torch.concat([k_nope, k_pe.expand(-1, self.num_heads, -1, -1)], dim=-1)
+    # key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+    # key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
+    # key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
+    if past_key_value is not None:
+        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+    if attention_mask is not None:
+        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+            )
+
+    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+    # Reference: https://github.com/pytorch/pytorch/issues/112577.
+    if query_states.device.type == "cuda" and attention_mask is not None:
+        query_states = query_states.contiguous()
+        key_states = key_states.contiguous()
+        value_states = value_states.contiguous()
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query_states,
+        key_states,
+        value_states,
+        attn_mask=attention_mask,
+        dropout_p=self.attention_dropout if self.training else 0.0,
+        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+        is_causal=self.is_causal and attention_mask is None and q_len > 1,
+    )
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
+
+    attn_output = self.o_proj(attn_output)
+
+    return attn_output, None, past_key_value
+
+
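Compared with the v3 variant, this apply_rotary_pos_emb first regroups the rotary dimensions: the view/transpose/reshape converts the interleaved even/odd pair layout into the half-split layout that rotate_half expects. A tiny illustration of that regrouping (shapes chosen only for readability):

import torch

b, h, s, d = 1, 1, 1, 6
q = torch.arange(d, dtype=torch.float32).view(b, h, s, d)  # [0, 1, 2, 3, 4, 5]
regrouped = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
print(regrouped.flatten().tolist())  # [0.0, 2.0, 4.0, 1.0, 3.0, 5.0]: even dims first, odd dims second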
+def deepseek_moe_infer(self, x, topk_ids, topk_weight):
+    cnts = torch.zeros((topk_ids.shape[0], len(self.experts)))
+    cnts.scatter_(1, topk_ids, 1)
+    tokens_per_expert = cnts.sum(dim=0).to(torch.long)
+    idxs = torch.argsort(topk_ids.view(-1))
+    sorted_tokens = x[idxs // topk_ids.shape[1]]
+
+    outputs = []
+    start_idx = torch.tensor(0, dtype=torch.long)
+    for i, num_tokens in enumerate(tokens_per_expert):
+        end_idx = start_idx + num_tokens
+        # difference from the original: experts with an empty token slice are no longer skipped
+        expert_id = i + self.ep_rank * self.experts_per_rank
+        expert = self.experts[expert_id]
+        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
+        expert_out = expert(tokens_for_this_expert)
+        outputs.append(expert_out)
+        start_idx = end_idx
+
+    # difference from the original: removed the torch.new_empty fallback used when outputs is empty
+    outs = torch.cat(outputs, dim=0)
+
+    new_x = torch.zeros_like(outs)
+    new_x[idxs] = outs
+    final_out = (
+        new_x.view(*topk_ids.shape, -1)
+        .to(topk_weight.dtype)
+        .mul_(topk_weight.unsqueeze(dim=-1))
+        .sum(dim=1)
+        .to(new_x.dtype)
+    )
+    return final_out
+
+
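The routing bookkeeping at the top of deepseek_moe_infer is easier to follow on a concrete case. A small self-contained sketch (toy routing table, hypothetical sizes, no expert weights involved):

import torch

# 3 tokens, 4 experts, top_k = 2
topk_ids = torch.tensor([[0, 2], [2, 3], [0, 1]])
num_experts = 4

# one-hot counts per (token, expert), then per-expert token counts
cnts = torch.zeros((topk_ids.shape[0], num_experts))
cnts.scatter_(1, topk_ids, 1)
tokens_per_expert = cnts.sum(dim=0).to(torch.long)
print(tokens_per_expert.tolist())  # [2, 1, 2, 1]: experts 0 and 2 each receive two tokens

# sort the flattened (token, slot) assignments by expert id; dividing by top_k
# recovers the source token row of each assignment, now grouped per expert, so
# each expert can process a contiguous slice of x[token_rows]
idxs = torch.argsort(topk_ids.view(-1))
token_rows = idxs // topk_ids.shape[1]
print(token_rows.tolist())  # rows grouped per expert; ties within an expert may appear in either order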
class Qwen2VLLanguageModelPatcher(DecoderModelPatcher):
    def __init__(
        self,