@@ -3935,14 +3935,28 @@ def __enter__(self):
 # Modified from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L390
 # added attention_mask input instead of internal calculation (unsupported by tracing due to cycle with dynamic len)
 def sdpa_attn_forward(
-    self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, rotary_pos_emb: torch.Tensor = None
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: torch.Tensor,
+    rotary_pos_emb: torch.Tensor = None,
+    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
 ) -> torch.Tensor:
     from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_rotary_pos_emb_vision

     seq_length = hidden_states.shape[0]
     q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-    q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
-    k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+
+    if is_transformers_version(">=", "4.49"):
+        if position_embeddings is None:
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos().float()
+            sin = emb.sin().float()
+        else:
+            cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
+    else:
+        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

     q = q.transpose(0, 1)
     k = k.transpose(0, 1)
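Context for the new branch: as the diff shows, from transformers 4.49 onward apply_rotary_pos_emb_vision is called with q and k plus precomputed cos/sin tables instead of a single tensor plus the raw rotary frequencies, so the patch derives cos/sin from rotary_pos_emb whenever position_embeddings is not supplied. The snippet below is an illustrative sketch only, not the library's implementation: the rotate_half helper and the example shapes are assumptions used to show how the cos/sin fallback built in the diff would be applied in the usual rotate-half rotary scheme.

import torch

def rotate_half(x):
    # standard RoPE helper: split the last dim in half, negate the second half, swap
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

# hypothetical example shapes
seq_len, num_heads, head_dim = 16, 4, 64
q = torch.randn(seq_len, num_heads, head_dim)
rotary_pos_emb = torch.randn(seq_len, head_dim // 2)  # per-position rotary frequencies

# fallback from the diff when position_embeddings is None
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)  # (seq_len, head_dim)
cos = emb.cos().float()
sin = emb.sin().float()

# apply the rotation, broadcasting cos/sin over the head dimension
q_rot = q * cos.unsqueeze(1) + rotate_half(q) * sin.unsqueeze(1)
print(q_rot.shape)  # torch.Size([16, 4, 64])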