if TYPE_CHECKING:
-    from PIL import Image
+    from PIL.Image import Image


logger = logging.getLogger(__name__)
@@ -166,9 +166,6 @@ def prepare_inputs(
        if past_len:
            position_ids = position_ids[:, -inputs_embeds.shape[1] :]

-        if self.config.model_type == "qwen2_vl" and position_ids.ndim != 3:
-            position_ids = np.repeat(np.expand_dims(position_ids, 0), 3, axis=0)
-
        inputs["position_ids"] = position_ids

        if "beam_idx" in self.input_names:
@@ -2100,6 +2097,8 @@ def __init__(
            quantization_config=quantization_config,
            **kwargs,
        )
+        self.rope_deltas = None  # cache rope_deltas here
+
        if is_transformers_version(">=", "4.45.0"):
            from transformers.models.qwen2_vl.modeling_qwen2_vl import (
                Qwen2VLForConditionalGeneration,
@@ -2197,6 +2196,7 @@ def get_multimodal_embeddings(
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
+        cache_position=None,
        **kwargs,
    ):
        inputs_embeds = torch.from_numpy(self.get_text_embeddings(input_ids))
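
The new cache_position argument is assumed here to follow the usual transformers generation convention (not shown in this diff): during pre-fill it is an arange over the prompt tokens, so cache_position[0] == 0, and during decoding it holds the single index of the token being generated. A rough sketch of the values a generation loop would pass:

import torch

prompt_len = 7
prefill_cache_position = torch.arange(prompt_len)   # tensor([0, 1, ..., 6]); cache_position[0] == 0
decode_cache_position = torch.tensor([prompt_len])  # tensor([7]) for the first generated token
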
@@ -2209,6 +2209,26 @@ def get_multimodal_embeddings(
            video_embeds = torch.from_numpy(self.get_vision_embeddings(pixel_values_videos, video_grid_thw))
            video_mask = input_ids == self.config.video_token_id
            inputs_embeds[video_mask] = video_embeds
+
+        # if we get 4D attention mask we cannot calculate rope deltas anymore.
+        if position_ids is None and input_ids is not None and (attention_mask is None or attention_mask.ndim == 2):
+            # calculate RoPE index once per generation in the pre-fill stage only
+            if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
+                position_ids, rope_deltas = self.get_rope_index(
+                    input_ids, image_grid_thw, video_grid_thw, attention_mask
+                )
+                self.rope_deltas = rope_deltas
+            # then use the prev pre-calculated rope-deltas to get the correct position ids
+            else:
+                batch_size, seq_length, _ = inputs_embeds.shape
+                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
+                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+                if cache_position is not None:  # otherwise `deltas` is an int `0`
+                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                position_ids = position_ids.add(delta)
+                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
        return inputs_embeds, attention_mask, position_ids

    def forward(
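
To make the decode-time branch above concrete, here is a minimal, standalone sketch of how the cached rope delta turns a cache position into 3-axis position ids. The values are invented, and cached_rope_deltas is a stand-in for the self.rope_deltas attribute filled in at pre-fill by get_rope_index:

import torch

batch_size, seq_length = 1, 1                   # one new token per decode step
cache_position = torch.tensor([7])              # index of the token being generated
cached_rope_deltas = torch.tensor([[-2]])       # (batch, 1), cached during pre-fill

delta = cache_position[0] + cached_rope_deltas  # shift text positions by the cached delta
position_ids = torch.arange(seq_length).view(1, -1).expand(batch_size, -1)
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
position_ids = position_ids.add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
print(position_ids.shape, position_ids.flatten().tolist())  # torch.Size([3, 1, 1]) [5, 5, 5]

All three RoPE axes share the same value for text-only decode tokens; they only diverge for image/video tokens, whose indices get_rope_index assigns during pre-fill.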