@@ -3614,29 +3614,8 @@ def __exit__(self, exc_type, exc_value, traceback):
         self._model.forward = self._model.__orig_forward


-class Qwen2VLVisionEmbMergerPatcher(ModelPatcher):
-    def __init__(
-        self,
-        config: "OnnxConfig",
-        model: Union["PreTrainedModel", "TFPreTrainedModel"],
-        model_kwargs: Dict[str, Any] = None,
-    ):
-        model.__orig_forward = model.forward
-
-        # Modified from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1118
-        # added attention_mask input instead cu_lens for its internal calculation model (unsupported by tracing due to cycle with dynamic len)
-        # separated patch_embed and rot_pos_emb calls for performing as part of another model
-        def image_embed_forward(
-            self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, rotary_pos_emb: torch.Tensor
-        ) -> torch.Tensor:
-            for blk in self.blocks:
-                hidden_states = blk(hidden_states, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb)
-            return self.merger(hidden_states)
-
-        model.forward = types.MethodType(image_embed_forward, model)
-        super().__init__(config, model, model_kwargs)
-
-    def __enter__(self):
+def patch_qwen2vl_vision_blocks(model, force_new_behaviour=False):
+    if not force_new_behaviour and is_transformers_version("<=", "4.48.99"):
         # Modified from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L390
         # added attention_mask input instead of internal calculation (unsupported by tracing due to cycle with dynamic len)
         def sdpa_attn_forward(
@@ -3667,11 +3646,161 @@ def block_forward(self, hidden_states, attention_mask, rotary_pos_emb) -> torch.
             hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
             return hidden_states

+    else:
+
+        def sdpa_attn_forward(
+            self,
+            hidden_states: torch.Tensor,
+            attention_mask: torch.Tensor,
+            rotary_pos_emb: torch.Tensor = None,
+            position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        ):
+            def rotate_half(x):
+                """Rotates half the hidden dims of the input."""
+                x1 = x[..., : x.shape[-1] // 2]
+                x2 = x[..., x.shape[-1] // 2 :]
+                return torch.cat((-x2, x1), dim=-1)
+
+            def apply_rotary_pos_emb_vision(
+                q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+            ) -> Tuple[torch.Tensor, torch.Tensor]:
+                orig_q_dtype = q.dtype
+                orig_k_dtype = k.dtype
+                q, k = q.float(), k.float()
+                cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)
+                q_embed = (q * cos) + (rotate_half(q) * sin)
+                k_embed = (k * cos) + (rotate_half(k) * sin)
+                q_embed = q_embed.to(orig_q_dtype)
+                k_embed = k_embed.to(orig_k_dtype)
+                return q_embed, k_embed
+
+            seq_length = hidden_states.shape[0]
+            q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+            if position_embeddings is None:
+                emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+                cos = emb.cos().float()
+                sin = emb.sin().float()
+            else:
+                cos, sin = position_embeddings
+            q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
+            q = q.transpose(0, 1)
+            k = k.transpose(0, 1)
+            v = v.transpose(0, 1)
+            attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
+            attn_output = attn_output.transpose(0, 1)
+            attn_output = attn_output.reshape(seq_length, -1)
+            attn_output = self.proj(attn_output)
+            return attn_output
+
+        def block_forward(
+            self,
+            hidden_states,
+            attention_mask,
+            rotary_pos_emb: Optional[torch.Tensor] = None,
+            position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        ) -> torch.Tensor:
+            hidden_states = hidden_states + self.attn(
+                self.norm1(hidden_states),
+                attention_mask=attention_mask,
+                rotary_pos_emb=rotary_pos_emb,
+                position_embeddings=position_embeddings,
+            )
+            hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+            return hidden_states
+
+    for block in model.blocks:
+        block._orig_forward = block.forward
+        block.forward = types.MethodType(block_forward, block)
+        block.attn._orig_forward = block.attn.forward
+        block.attn.forward = types.MethodType(sdpa_attn_forward, block.attn)
+
+
+class Qwen2VLVisionEmbMergerPatcher(ModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Dict[str, Any] = None,
+    ):
+        model.__orig_forward = model.forward
+
+        # Modified from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1118
+        # added attention_mask input instead cu_lens for its internal calculation model (unsupported by tracing due to cycle with dynamic len)
+        # separated patch_embed and rot_pos_emb calls for performing as part of another model
+        def image_embed_forward(
+            self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, rotary_pos_emb: torch.Tensor
+        ) -> torch.Tensor:
+            for blk in self.blocks:
+                hidden_states = blk(hidden_states, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb)
+            return self.merger(hidden_states)
+
+        model.forward = types.MethodType(image_embed_forward, model)
+        super().__init__(config, model, model_kwargs)
+
+    def __enter__(self):
+        patch_qwen2vl_vision_blocks(self._model)
+        super().__enter__()
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        self._model.forward = self._model.__orig_forward
         for block in self._model.blocks:
-            block._orig_forward = block.forward
-            block.forward = types.MethodType(block_forward, block)
-            block.attn._orig_forward = block.attn.forward
-            block.attn.forward = types.MethodType(sdpa_attn_forward, block.attn)
+            block.forward = block._orig_forward
+            block.attn.forward = block.attn._orig_forward
+
+
+class Qwen2_5_VLVisionEmbMergerPatcher(ModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Dict[str, Any] = None,
+    ):
+        super().__init__(config, model, model_kwargs)
+
+        model.__orig_forward = model.forward
+
+        # Modified from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1118
+        # added attention_mask input instead cu_lens for its internal calculation model (unsupported by tracing due to cycle with dynamic len)
+        # separated patch_embed and rot_pos_emb calls for performing as part of another model
+        def image_embed_forward(
+            self,
+            hidden_states: torch.Tensor,
+            attention_mask: torch.Tensor,
+            window_attention_mask: torch.Tensor,
+            window_index: torch.Tensor,
+            rotary_pos_emb: torch.Tensor,
+        ) -> torch.Tensor:
+            seq_len = hidden_states.shape[0]
+            hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+            hidden_states = hidden_states[window_index, :, :]
+            hidden_states = hidden_states.reshape(seq_len, -1)
+            rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+            rotary_pos_emb = rotary_pos_emb[window_index, :, :]
+            rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            position_embeddings = (emb.cos(), emb.sin())
+            for layer_num, blk in enumerate(self.blocks):
+                if layer_num in self.fullatt_block_indexes:
+                    attention_mask_now = attention_mask
+                else:
+                    attention_mask_now = window_attention_mask
+                hidden_states = blk(
+                    hidden_states, attention_mask=attention_mask_now, position_embeddings=position_embeddings
+                )
+
+            hidden_states = self.merger(hidden_states)
+            reverse_indices = torch.argsort(window_index)
+            hidden_states = hidden_states[reverse_indices, :]
+
+            return hidden_states
+
+        model.forward = types.MethodType(image_embed_forward, model)
+        super().__init__(config, model, model_kwargs)
+
+    def __enter__(self):
+        patch_qwen2vl_vision_blocks(self._model, force_new_behaviour=True)
+        super().__enter__()

     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
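
Usage sketch (not part of the diff above): a minimal illustration of how a patcher such as Qwen2VLVisionEmbMergerPatcher is typically applied around tracing of the vision tower. The names onnx_config, vision_model, and the input tensors below are hypothetical placeholders supplied by the surrounding export pipeline, not anything defined in this diff.

# Hypothetical sketch, assuming the export pipeline already provides an export config
# (`onnx_config`), the Qwen2-VL vision tower (`vision_model`), and precomputed inputs.
patcher = Qwen2VLVisionEmbMergerPatcher(onnx_config, vision_model)
with patcher:
    # While the context is active, the vision tower exposes the tracing-friendly signature
    # forward(hidden_states, attention_mask, rotary_pos_emb), and each block's attention
    # consumes an explicit attention_mask instead of deriving cu_seqlens internally.
    image_embeds = vision_model(hidden_states, attention_mask, rotary_pos_emb)
# On exit, the original forward of the merger model and of every block/attention module is restored.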