
Commit c6c4a25

latest qwen2 vl position_ids formula
1 parent 35c47a2 commit c6c4a25

File tree: 2 files changed (+30, −10 lines)

optimum/exporters/openvino/model_patcher.py (+6, −6)
@@ -421,9 +421,9 @@ def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, c
             offset = 0
             mask_shape = attention_mask.shape
             mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-            causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
-                mask_slice
-            )
+            causal_mask[
+                : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
+            ] = mask_slice
 
     if (
         self.config._attn_implementation == "sdpa"
@@ -2058,9 +2058,9 @@ def _dbrx_update_causal_mask_legacy(
             offset = 0
             mask_shape = attention_mask.shape
             mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-            causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
-                mask_slice
-            )
+            causal_mask[
+                : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
+            ] = mask_slice
 
     if (
         self.config._attn_implementation == "sdpa"
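
Both hunks in this file are cosmetic: the causal-mask assignment is only re-wrapped for line length, and the slice it writes is unchanged. A minimal sketch, using a toy 4D mask (shapes and values are illustrative assumptions, not taken from the patched functions), showing the two formattings assign identically:

import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min

# toy 4D attention mask: (batch, heads, query_len, kv_len)
attention_mask = torch.tensor([[[[1.0, 1.0, 0.0, 0.0]]]])
causal_mask = torch.zeros(1, 1, 1, 4, dtype=dtype)
offset = 0
mask_shape = attention_mask.shape
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype

# formatting before this commit
old = causal_mask.clone()
old[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
    mask_slice
)

# formatting after this commit
new = causal_mask.clone()
new[
    : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
] = mask_slice

assert torch.equal(old, new)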

optimum/intel/openvino/modeling_visual_language.py (+24, −4)
@@ -53,7 +53,7 @@
 
 
 if TYPE_CHECKING:
-    from PIL import Image
+    from PIL.Image import Image
 
 
 logger = logging.getLogger(__name__)
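
The TYPE_CHECKING change swaps the PIL.Image module for the PIL.Image.Image class, which is what type annotations expect. A minimal sketch of the distinction, with a hypothetical helper that exists only for illustration:

from typing import TYPE_CHECKING, List

if TYPE_CHECKING:
    from PIL.Image import Image  # the Image class; `from PIL import Image` yields the module


def count_frames(images: List["Image"]) -> int:
    # hypothetical helper, only here to show the class used as a type hint
    return len(images)
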
@@ -166,9 +166,6 @@ def prepare_inputs(
                 if past_len:
                     position_ids = position_ids[:, -inputs_embeds.shape[1] :]
 
-            if self.config.model_type == "qwen2_vl" and position_ids.ndim != 3:
-                position_ids = np.repeat(np.expand_dims(position_ids, 0), 3, axis=0)
-
             inputs["position_ids"] = position_ids
 
         if "beam_idx" in self.input_names:
@@ -2100,6 +2097,8 @@ def __init__(
             quantization_config=quantization_config,
             **kwargs,
         )
+        self.rope_deltas = None  # cache rope_deltas here
+
         if is_transformers_version(">=", "4.45.0"):
             from transformers.models.qwen2_vl.modeling_qwen2_vl import (
                 Qwen2VLForConditionalGeneration,
@@ -2197,6 +2196,7 @@ def get_multimodal_embeddings(
         pixel_values_videos=None,
         image_grid_thw=None,
         video_grid_thw=None,
+        cache_position=None,
         **kwargs,
     ):
         inputs_embeds = torch.from_numpy(self.get_text_embeddings(input_ids))
@@ -2209,6 +2209,26 @@ def get_multimodal_embeddings(
             video_embeds = torch.from_numpy(self.get_vision_embeddings(pixel_values_videos, video_grid_thw))
             video_mask = input_ids == self.config.video_token_id
             inputs_embeds[video_mask] = video_embeds
+
+        # if we get a 4D attention mask we cannot calculate rope deltas anymore.
+        if position_ids is None and input_ids is not None and (attention_mask is None or attention_mask.ndim == 2):
+            # calculate the RoPE index once per generation, in the pre-fill stage only
+            if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
+                position_ids, rope_deltas = self.get_rope_index(
+                    input_ids, image_grid_thw, video_grid_thw, attention_mask
+                )
+                self.rope_deltas = rope_deltas
+            # then use the previously pre-calculated rope_deltas to get the correct position ids
+            else:
+                batch_size, seq_length, _ = inputs_embeds.shape
+                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
+                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+                if cache_position is not None:  # otherwise `delta` is an int `0`
+                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                position_ids = position_ids.add(delta)
+                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
         return inputs_embeds, attention_mask, position_ids
 
     def forward(
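
A minimal sketch of the decode-stage branch added above, with stand-in values for rope_deltas and cache_position (get_rope_index is not called here, so the numbers are illustrative only):

import torch

batch_size, seq_length = 2, 1                # one new token per sequence during decoding
inputs_embeds = torch.zeros(batch_size, seq_length, 8)
cache_position = torch.tensor([7])           # index of the first new token in the KV cache
rope_deltas = torch.tensor([[-2], [0]])      # (batch, 1), cached at the pre-fill stage

delta = cache_position[0] + rope_deltas                     # (batch, 1)
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
position_ids = position_ids.add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)  # (3, batch, seq) for M-RoPE

assert position_ids.shape == (3, batch_size, seq_length)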
