@@ -3909,29 +3909,8 @@ def __exit__(self, exc_type, exc_value, traceback):
3909
3909
self ._model .forward = self ._model .__orig_forward
3910
3910
3911
3911
3912
- class Qwen2VLVisionEmbMergerPatcher (ModelPatcher ):
3913
- def __init__ (
3914
- self ,
3915
- config : "OnnxConfig" ,
3916
- model : Union ["PreTrainedModel" , "TFPreTrainedModel" ],
3917
- model_kwargs : Dict [str , Any ] = None ,
3918
- ):
3919
- model .__orig_forward = model .forward
3920
-
3921
- # Modified from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1118
3922
- # added attention_mask input instead cu_lens for its internal calculation model (unsupported by tracing due to cycle with dynamic len)
3923
- # separated patch_embed and rot_pos_emb calls for performing as part of another model
3924
- def image_embed_forward (
3925
- self , hidden_states : torch .Tensor , attention_mask : torch .Tensor , rotary_pos_emb : torch .Tensor
3926
- ) -> torch .Tensor :
3927
- for blk in self .blocks :
3928
- hidden_states = blk (hidden_states , attention_mask = attention_mask , rotary_pos_emb = rotary_pos_emb )
3929
- return self .merger (hidden_states )
3930
-
3931
- model .forward = types .MethodType (image_embed_forward , model )
3932
- super ().__init__ (config , model , model_kwargs )
3933
-
3934
- def __enter__ (self ):
3912
+ def patch_qwen2vl_vision_blocks (model , force_new_behaviour = False ):
3913
+ if not force_new_behaviour and is_transformers_version ("<=" , "4.48.99" ):
3935
3914
# Modified from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L390
3936
3915
# added attention_mask input instead of internal calculation (unsupported by tracing due to cycle with dynamic len)
3937
3916
def sdpa_attn_forward (
@@ -3976,11 +3955,165 @@ def block_forward(self, hidden_states, attention_mask, rotary_pos_emb) -> torch.
3976
3955
hidden_states = hidden_states + self .mlp (self .norm2 (hidden_states ))
3977
3956
return hidden_states
3978
3957
3958
+ else :
3959
+ # Modified from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L391
3960
+ # added attention_mask input instead of internal calculation (unsupported by tracing due to cycle with dynamic len)
3961
def sdpa_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
    rotary_pos_emb: Optional[torch.Tensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
    """SDPA vision-attention forward taking an explicit ``attention_mask``.

    Replacement for the transformers>=4.49 Qwen2-VL vision attention forward:
    the upstream version derives the mask internally from ``cu_seqlens``, which
    cannot be traced (dynamic-length loop), so the precomputed mask is passed
    in instead.

    Args:
        hidden_states: flattened patch sequence; assumes shape (seq_len, embed_dim)
            given the ``shape[0]`` / qkv reshape below — TODO confirm upstream.
        attention_mask: mask forwarded verbatim to scaled_dot_product_attention.
        rotary_pos_emb: half-dim rotary angles, used only when
            ``position_embeddings`` is not supplied.
        position_embeddings: optional precomputed (cos, sin) pair that takes
            precedence over ``rotary_pos_emb``.

    Returns:
        Projected attention output of shape (seq_len, embed_dim).
    """

    def rotate_half(x):
        """Rotates half the hidden dims of the input."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb_vision(
        q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Rotation is done in float32 for numerical stability, then cast back
        # to the original dtypes.
        orig_q_dtype = q.dtype
        orig_k_dtype = k.dtype
        q, k = q.float(), k.float()
        cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        q_embed = q_embed.to(orig_q_dtype)
        k_embed = k_embed.to(orig_k_dtype)
        return q_embed, k_embed

    seq_length = hidden_states.shape[0]
    # Single fused qkv projection, split into q/k/v of shape
    # (seq_len, num_heads, head_dim) via unbind on the packed axis.
    q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
    if position_embeddings is None:
        # Duplicate the half-dim angles to full head_dim, as upstream does.
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        cos = emb.cos().float()
        sin = emb.sin().float()
    else:
        cos, sin = position_embeddings
    q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
    # SDPA expects (num_heads, seq_len, head_dim) here; move heads first.
    q = q.transpose(0, 1)
    k = k.transpose(0, 1)
    v = v.transpose(0, 1)
    attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
    attn_output = attn_output.transpose(0, 1)
    # Merge heads back into a single embedding dimension before projection.
    attn_output = attn_output.reshape(seq_length, -1)
    attn_output = self.proj(attn_output)
    return attn_output
4004
+
4005
+ # Modified from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L446
4006
+ # added attention_mask input propagation to self.attn
4007
def block_forward(
    self,
    hidden_states,
    attention_mask,
    rotary_pos_emb: Optional[torch.Tensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
    """Pre-norm transformer block with explicit attention-mask propagation.

    Mirrors the upstream Qwen2-VL vision block forward, except that
    ``attention_mask`` is threaded through to ``self.attn`` instead of being
    recomputed internally (which is untraceable for dynamic lengths).
    """
    attn_out = self.attn(
        self.norm1(hidden_states),
        attention_mask=attention_mask,
        rotary_pos_emb=rotary_pos_emb,
        position_embeddings=position_embeddings,
    )
    hidden_states = hidden_states + attn_out
    mlp_out = self.mlp(self.norm2(hidden_states))
    return hidden_states + mlp_out
4022
+
4023
+ for block in model .blocks :
4024
+ block ._orig_forward = block .forward
4025
+ block .forward = types .MethodType (block_forward , block )
4026
+ block .attn ._orig_forward = block .attn .forward
4027
+ block .attn .forward = types .MethodType (sdpa_attn_forward , block .attn )
4028
+
4029
+
4030
class Qwen2VLVisionEmbMergerPatcher(ModelPatcher):
    """Patcher for the Qwen2-VL vision embedding/merger module.

    Swaps the module's ``forward`` for a trace-friendly variant that receives a
    precomputed ``attention_mask`` and ``rotary_pos_emb`` (instead of deriving
    them internally from dynamic sequence lengths), and patches each vision
    block on entry via ``patch_qwen2vl_vision_blocks``.
    """

    def __init__(
        self,
        config: "OnnxConfig",
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        # Save the unpatched forward so __exit__ can restore it. Note: the
        # attribute name is mangled (class-private) — __exit__ below reads it
        # through the same mangled name, so the round-trip is consistent.
        model.__orig_forward = model.forward

        # Modified from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1118
        # added attention_mask input instead cu_lens for its internal calculation model (unsupported by tracing due to cycle with dynamic len)
        # separated patch_embed and rot_pos_emb calls for performing as part of another model
        def image_embed_forward(
            self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, rotary_pos_emb: torch.Tensor
        ) -> torch.Tensor:
            # Run all transformer blocks with the externally supplied mask and
            # rotary embedding, then merge patches into image embeddings.
            for blk in self.blocks:
                hidden_states = blk(hidden_states, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb)
            return self.merger(hidden_states)

        model.forward = types.MethodType(image_embed_forward, model)
        super().__init__(config, model, model_kwargs)

    def __enter__(self):
        # Patch each vision block (and its attention) to take an explicit mask.
        patch_qwen2vl_vision_blocks(self._model)
        super().__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        # Undo the forward replacements made in __init__ and in
        # patch_qwen2vl_vision_blocks (which stores block._orig_forward /
        # block.attn._orig_forward before patching).
        self._model.forward = self._model.__orig_forward
        for block in self._model.blocks:
            block.forward = block._orig_forward
            block.attn.forward = block.attn._orig_forward
4062
+
4063
+
4064
+ class Qwen2_5_VLVisionEmbMergerPatcher (ModelPatcher ):
4065
+ def __init__ (
4066
+ self ,
4067
+ config : "OnnxConfig" ,
4068
+ model : Union ["PreTrainedModel" , "TFPreTrainedModel" ],
4069
+ model_kwargs : Dict [str , Any ] = None ,
4070
+ ):
4071
+ super ().__init__ (config , model , model_kwargs )
4072
+
4073
+ model .__orig_forward = model .forward
4074
+
4075
+ # Modified from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L405
4076
+ # added attention_mask and window_attention_mask inputs instead cu_lens and window_cu_lens processing for its internal calculation model
4077
+ # (unsupported by tracing due to cycle with dynamic len)
4078
+ # separated patch_embed and rot_pos_emb calls for performing as part of another model
4079
+ def image_embed_forward (
4080
+ self ,
4081
+ hidden_states : torch .Tensor ,
4082
+ attention_mask : torch .Tensor ,
4083
+ window_attention_mask : torch .Tensor ,
4084
+ window_index : torch .Tensor ,
4085
+ rotary_pos_emb : torch .Tensor ,
4086
+ ) -> torch .Tensor :
4087
+ seq_len = hidden_states .shape [0 ]
4088
+ hidden_states = hidden_states .reshape (seq_len // self .spatial_merge_unit , self .spatial_merge_unit , - 1 )
4089
+ hidden_states = hidden_states [window_index , :, :]
4090
+ hidden_states = hidden_states .reshape (seq_len , - 1 )
4091
+ rotary_pos_emb = rotary_pos_emb .reshape (seq_len // self .spatial_merge_unit , self .spatial_merge_unit , - 1 )
4092
+ rotary_pos_emb = rotary_pos_emb [window_index , :, :]
4093
+ rotary_pos_emb = rotary_pos_emb .reshape (seq_len , - 1 )
4094
+ emb = torch .cat ((rotary_pos_emb , rotary_pos_emb ), dim = - 1 )
4095
+ position_embeddings = (emb .cos (), emb .sin ())
4096
+ for layer_num , blk in enumerate (self .blocks ):
4097
+ if layer_num in self .fullatt_block_indexes :
4098
+ attention_mask_now = attention_mask
4099
+ else :
4100
+ attention_mask_now = window_attention_mask
4101
+ hidden_states = blk (
4102
+ hidden_states , attention_mask = attention_mask_now , position_embeddings = position_embeddings
4103
+ )
4104
+
4105
+ hidden_states = self .merger (hidden_states )
4106
+ reverse_indices = torch .argsort (window_index )
4107
+ hidden_states = hidden_states [reverse_indices , :]
4108
+
4109
+ return hidden_states
4110
+
4111
+ model .forward = types .MethodType (image_embed_forward , model )
4112
+ super ().__init__ (config , model , model_kwargs )
4113
+
4114
    def __enter__(self):
        # force_new_behaviour=True: Qwen2.5-VL vision blocks always use the
        # transformers>=4.49-style patched forwards (position_embeddings
        # support), regardless of the installed transformers version.
        patch_qwen2vl_vision_blocks(self._model, force_new_behaviour=True)
        super().__enter__()
3984
4117
3985
4118
def __exit__ (self , exc_type , exc_value , traceback ):
3986
4119
super ().__exit__ (exc_type , exc_value , traceback )
0 commit comments