@@ -3614,29 +3614,8 @@ def __exit__(self, exc_type, exc_value, traceback):
        self._model.forward = self._model.__orig_forward


-class Qwen2VLVisionEmbMergerPatcher(ModelPatcher):
-    def __init__(
-        self,
-        config: "OnnxConfig",
-        model: Union["PreTrainedModel", "TFPreTrainedModel"],
-        model_kwargs: Dict[str, Any] = None,
-    ):
-        model.__orig_forward = model.forward
-
-        # Modified from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1118
-        # added attention_mask input instead cu_lens for its internal calculation model (unsupported by tracing due to cycle with dynamic len)
-        # separated patch_embed and rot_pos_emb calls for performing as part of another model
-        def image_embed_forward(
-            self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, rotary_pos_emb: torch.Tensor
-        ) -> torch.Tensor:
-            for blk in self.blocks:
-                hidden_states = blk(hidden_states, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb)
-            return self.merger(hidden_states)
-
-        model.forward = types.MethodType(image_embed_forward, model)
-        super().__init__(config, model, model_kwargs)
-
-    def __enter__(self):
+def patch_qwen2vl_vision_blocks(model, force_new_behaviour=False):
+    if not force_new_behaviour and is_transformers_version("<=", "4.48.99"):
        # Modified from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L390
        # added attention_mask input instead of internal calculation (unsupported by tracing due to cycle with dynamic len)
        def sdpa_attn_forward(
@@ -3667,11 +3646,162 @@ def block_forward(self, hidden_states, attention_mask, rotary_pos_emb) -> torch.
            hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
            return hidden_states

+    else:
+
+        def sdpa_attn_forward(
+            self,
+            hidden_states: torch.Tensor,
+            attention_mask: torch.Tensor,
+            rotary_pos_emb: torch.Tensor = None,
+            position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        ):
+            def rotate_half(x):
+                """Rotates half the hidden dims of the input."""
+                x1 = x[..., : x.shape[-1] // 2]
+                x2 = x[..., x.shape[-1] // 2 :]
+                return torch.cat((-x2, x1), dim=-1)
+
+
+            def apply_rotary_pos_emb_vision(
+                q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+            ) -> Tuple[torch.Tensor, torch.Tensor]:
+                orig_q_dtype = q.dtype
+                orig_k_dtype = k.dtype
+                q, k = q.float(), k.float()
+                cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)
+                q_embed = (q * cos) + (rotate_half(q) * sin)
+                k_embed = (k * cos) + (rotate_half(k) * sin)
+                q_embed = q_embed.to(orig_q_dtype)
+                k_embed = k_embed.to(orig_k_dtype)
+                return q_embed, k_embed
+
+            seq_length = hidden_states.shape[0]
+            q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+            if position_embeddings is None:
+                emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+                cos = emb.cos().float()
+                sin = emb.sin().float()
+            else:
+                cos, sin = position_embeddings
+            q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
+            q = q.transpose(0, 1)
+            k = k.transpose(0, 1)
+            v = v.transpose(0, 1)
+            attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
+            attn_output = attn_output.transpose(0, 1)
+            attn_output = attn_output.reshape(seq_length, -1)
+            attn_output = self.proj(attn_output)
+            return attn_output
+
+        def block_forward(
+            self,
+            hidden_states,
+            attention_mask,
+            rotary_pos_emb: Optional[torch.Tensor] = None,
+            position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        ) -> torch.Tensor:
+            hidden_states = hidden_states + self.attn(
+                self.norm1(hidden_states),
+                attention_mask=attention_mask,
+                rotary_pos_emb=rotary_pos_emb,
+                position_embeddings=position_embeddings,
+            )
+            hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+            return hidden_states
+
+    for block in model.blocks:
+        block._orig_forward = block.forward
+        block.forward = types.MethodType(block_forward, block)
+        block.attn._orig_forward = block.attn.forward
+        block.attn.forward = types.MethodType(sdpa_attn_forward, block.attn)
+
+
+class Qwen2VLVisionEmbMergerPatcher(ModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Dict[str, Any] = None,
+    ):
+        model.__orig_forward = model.forward
+
+        # Modified from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1118
+        # added attention_mask input instead cu_lens for its internal calculation model (unsupported by tracing due to cycle with dynamic len)
+        # separated patch_embed and rot_pos_emb calls for performing as part of another model
+        def image_embed_forward(
+            self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, rotary_pos_emb: torch.Tensor
+        ) -> torch.Tensor:
+            for blk in self.blocks:
+                hidden_states = blk(hidden_states, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb)
+            return self.merger(hidden_states)
+
+        model.forward = types.MethodType(image_embed_forward, model)
+        super().__init__(config, model, model_kwargs)
+
+    def __enter__(self):
+        patch_qwen2vl_vision_blocks(self._model)
+        super().__enter__()
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        self._model.forward = self._model.__orig_forward
        for block in self._model.blocks:
-            block._orig_forward = block.forward
-            block.forward = types.MethodType(block_forward, block)
-            block.attn._orig_forward = block.attn.forward
-            block.attn.forward = types.MethodType(sdpa_attn_forward, block.attn)
+            block.forward = block._orig_forward
+            block.attn.forward = block.attn._orig_forward
+
+
+class Qwen2_5_VLVisionEmbMergerPatcher(ModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Dict[str, Any] = None,
+    ):
+        super().__init__(config, model, model_kwargs)
+
+        model.__orig_forward = model.forward
+
+        # Modified from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1118
+        # added attention_mask input instead cu_lens for its internal calculation model (unsupported by tracing due to cycle with dynamic len)
+        # separated patch_embed and rot_pos_emb calls for performing as part of another model
+        def image_embed_forward(
+            self,
+            hidden_states: torch.Tensor,
+            attention_mask: torch.Tensor,
+            window_attention_mask: torch.Tensor,
+            window_index: torch.Tensor,
+            rotary_pos_emb: torch.Tensor,
+        ) -> torch.Tensor:
+            seq_len = hidden_states.shape[0]
+            hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+            hidden_states = hidden_states[window_index, :, :]
+            hidden_states = hidden_states.reshape(seq_len, -1)
+            rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+            rotary_pos_emb = rotary_pos_emb[window_index, :, :]
+            rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            position_embeddings = (emb.cos(), emb.sin())
+            for layer_num, blk in enumerate(self.blocks):
+                if layer_num in self.fullatt_block_indexes:
+                    attention_mask_now = attention_mask
+                else:
+                    attention_mask_now = window_attention_mask
+                hidden_states = blk(
+                    hidden_states, attention_mask=attention_mask_now, position_embeddings=position_embeddings
+                )
+
+            hidden_states = self.merger(hidden_states)
+            reverse_indices = torch.argsort(window_index)
+            hidden_states = hidden_states[reverse_indices, :]
+
+            return hidden_states
+
+        model.forward = types.MethodType(image_embed_forward, model)
+        super().__init__(config, model, model_kwargs)
+
+    def __enter__(self):
+        patch_qwen2vl_vision_blocks(self._model, force_new_behaviour=True)
+        super().__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
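Note on the patched inputs: the forwards above take a precomputed attention_mask in place of cu_seqlens / cu_window_seqlens, because building the block-diagonal mask inside the model needs a Python loop over a dynamic number of image segments, which breaks tracing. A minimal sketch of how such a mask could be prepared outside the traced graph (the helper name and the boolean-mask convention are illustrative assumptions, not code from this diff):

import torch


def build_block_attention_mask(cu_seqlens: torch.Tensor, seq_length: int) -> torch.Tensor:
    # Boolean mask usable with torch.nn.functional.scaled_dot_product_attention:
    # True means "may attend". Patches of image (or window) i occupy the slice
    # cu_seqlens[i - 1]:cu_seqlens[i] and attend only within that slice.
    mask = torch.zeros((1, seq_length, seq_length), dtype=torch.bool)
    for i in range(1, cu_seqlens.numel()):
        start, end = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
        mask[..., start:end, start:end] = True
    return mask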
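For context, ModelPatcher subclasses like these are applied as context managers, so the traceable forwards are only in place while the model is being converted. A hypothetical usage sketch (export_config, vision_model and example_inputs are placeholder names, not taken from this diff):

import openvino as ov

# export_config: the export config built for the vision embeddings part;
# vision_model: the Qwen2-VL visual tower; example_inputs: a dict with
# hidden_states, attention_mask and rotary_pos_emb tensors of matching shapes.
with Qwen2VLVisionEmbMergerPatcher(export_config, vision_model):
    ov_model = ov.convert_model(vision_model, example_input=example_inputs)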