Skip to content

Commit 0307ec9

Browse files
qwen2 vl position ids
1 parent 35c47a2 commit 0307ec9

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

optimum/exporters/openvino/model_patcher.py

+6 -6
Original file line number | Diff line number | Diff line change
@@ -421,9 +421,9 @@ def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, c
421421
offset = 0
422422
mask_shape = attention_mask.shape
423423
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
424-
causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
425-
mask_slice
426-
)
424+
causal_mask[
425+
: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
426+
] = mask_slice
427427

428428
if (
429429
self.config._attn_implementation == "sdpa"
@@ -2058,9 +2058,9 @@ def _dbrx_update_causal_mask_legacy(
20582058
offset = 0
20592059
mask_shape = attention_mask.shape
20602060
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
2061-
causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
2062-
mask_slice
2063-
)
2061+
causal_mask[
2062+
: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
2063+
] = mask_slice
20642064

20652065
if (
20662066
self.config._attn_implementation == "sdpa"

optimum/intel/openvino/modeling_visual_language.py

+4 -4
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,7 @@
5353

5454

5555
if TYPE_CHECKING:
56-
from PIL import Image
56+
from PIL.Image import Image
5757

5858

5959
logger = logging.getLogger(__name__)
@@ -166,9 +166,6 @@ def prepare_inputs(
166166
if past_len:
167167
position_ids = position_ids[:, -inputs_embeds.shape[1] :]
168168

169-
if self.config.model_type == "qwen2_vl" and position_ids.ndim != 3:
170-
position_ids = np.repeat(np.expand_dims(position_ids, 0), 3, axis=0)
171-
172169
inputs["position_ids"] = position_ids
173170

174171
if "beam_idx" in self.input_names:
@@ -2228,6 +2225,9 @@ def forward(
22282225
rope_deltas=None,
22292226
**kwargs,
22302227
):
2228+
if position_ids is None and input_ids is not None:
2229+
position_ids, _ = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
2230+
22312231
result = super().forward(
22322232
input_ids,
22332233
pixel_values,

0 commit comments

Comments
 (0)