Commit ad64bdf

reuse original methods if possible
1 parent d7ba440 commit ad64bdf

File tree

1 file changed: +10 -172 lines changed


optimum/intel/openvino/modeling_visual_language.py (+10, -172)
```diff
@@ -6,6 +6,7 @@
 from dataclasses import dataclass
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+from types import MethodType
 
 import numpy as np
 import openvino as ov
@@ -2099,187 +2100,24 @@ def __init__(
             **kwargs,
         )
         if is_transformers_version(">=", "4.45.0"):
-            from transformers.models.qwen2_vl.modeling_qwen2_vl import VisionRotaryEmbedding
+            from transformers.models.qwen2_vl.modeling_qwen2_vl import (
+                VisionRotaryEmbedding,
+                Qwen2VLForConditionalGeneration,
+            )
 
             self._rotary_pos_emb = VisionRotaryEmbedding(
                 self.config.vision_config.embed_dim // self.config.vision_config.num_heads // 2
             )
+            self.get_rope_index = MethodType(Qwen2VLForConditionalGeneration.get_rope_index, self)
+            self.prepare_inputs_for_generation = MethodType(
+                Qwen2VLForConditionalGeneration.prepare_inputs_for_generation, self
+            )
         else:
             raise ValueError(
                 f"Initialization model for {self.config.model_type} required at least transformers >= 4.45"
             )
 
-    def get_rope_index(
-        self,
-        input_ids: torch.LongTensor,
-        image_grid_thw: Optional[torch.LongTensor] = None,
-        video_grid_thw: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
-        """
-        spatial_merge_size = self.config.vision_config.spatial_merge_size
-        image_token_id = self.config.image_token_id
-        video_token_id = self.config.video_token_id
-        vision_start_token_id = self.config.vision_start_token_id
-        mrope_position_deltas = []
-        if image_grid_thw is not None or video_grid_thw is not None:
-            total_input_ids = input_ids
-            position_ids = torch.ones(
-                3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
-            )
-            image_index, video_index = 0, 0
-            for i, input_ids in enumerate(total_input_ids):
-                if attention_mask is not None:
-                    input_ids = input_ids[attention_mask[i] == 1]
-                vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
-                vision_tokens = input_ids[vision_start_indices + 1]
-                image_nums = (vision_tokens == image_token_id).sum()
-                video_nums = (vision_tokens == video_token_id).sum()
-                input_tokens = input_ids.tolist()
-                llm_pos_ids_list: list = []
-                st = 0
-                remain_images, remain_videos = image_nums, video_nums
-                for _ in range(image_nums + video_nums):
-                    if image_token_id in input_tokens and remain_images > 0:
-                        ed_image = input_tokens.index(image_token_id, st)
-                    else:
-                        ed_image = len(input_tokens) + 1
-                    if video_token_id in input_tokens and remain_videos > 0:
-                        ed_video = input_tokens.index(video_token_id, st)
-                    else:
-                        ed_video = len(input_tokens) + 1
-                    if ed_image < ed_video:
-                        t, h, w = (
-                            image_grid_thw[image_index][0],
-                            image_grid_thw[image_index][1],
-                            image_grid_thw[image_index][2],
-                        )
-                        image_index += 1
-                        remain_images -= 1
-                        ed = ed_image
-                    else:
-                        t, h, w = (
-                            video_grid_thw[video_index][0],
-                            video_grid_thw[video_index][1],
-                            video_grid_thw[video_index][2],
-                        )
-                        video_index += 1
-                        remain_videos -= 1
-                        ed = ed_video
-                    llm_grid_t, llm_grid_h, llm_grid_w = (
-                        t.item(),
-                        h.item() // spatial_merge_size,
-                        w.item() // spatial_merge_size,
-                    )
-                    text_len = ed - st
-
-                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
-
-                    t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
-                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
-                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
-                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
-                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
-
-                if st < len(input_tokens):
-                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-                    text_len = len(input_tokens) - st
-                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
-
-                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
-                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
-                mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
-            mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
-            return position_ids, mrope_position_deltas
-        else:
-            if attention_mask is not None:
-                position_ids = attention_mask.long().cumsum(-1) - 1
-                position_ids.masked_fill_(attention_mask == 0, 1)
-                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
-                max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
-                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
-            else:
-                position_ids = (
-                    torch.arange(input_ids.shape[1], device=input_ids.device)
-                    .view(1, 1, -1)
-                    .expand(3, input_ids.shape[0], -1)
-                )
-                mrope_position_deltas = torch.zeros(
-                    [input_ids.shape[0], 1],
-                    device=input_ids.device,
-                    dtype=input_ids.dtype,
-                )
-
-            return position_ids, mrope_position_deltas
-
-    def prepare_inputs_for_generation(
-        self,
-        input_ids,
-        past_key_values=None,
-        attention_mask=None,
-        inputs_embeds=None,
-        cache_position=None,
-        position_ids=None,
-        use_cache=True,
-        pixel_values=None,
-        pixel_values_videos=None,
-        image_grid_thw=None,
-        video_grid_thw=None,
-        **kwargs,
-    ):
-        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
-        # Exception 1: when passing input_embeds, input_ids may be missing entries
-        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
-        if past_key_values is not None:
-            if inputs_embeds is not None:  # Exception 1
-                input_ids = input_ids[:, -cache_position.shape[0] :]
-            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
-                input_ids = input_ids[:, cache_position]
-
-        rope_deltas = kwargs.get("rope_deltas", None)
-        if attention_mask is not None and position_ids is None:
-            if cache_position is None or (cache_position is not None and cache_position[0] == 0):
-                position_ids, rope_deltas = self.get_rope_index(
-                    input_ids, image_grid_thw, video_grid_thw, attention_mask
-                )
-            else:
-                batch_size, seq_length = input_ids.shape
-                delta = (
-                    cache_position[0] + rope_deltas if cache_position is not None and rope_deltas is not None else 0
-                )
-                position_ids = torch.arange(seq_length, device=input_ids.device)
-                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
-                position_ids = position_ids.add(delta)
-                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
-
-        if cache_position[0] != 0:
-            pixel_values = None
-            pixel_values_videos = None
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids}
-
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "past_key_values": past_key_values,
-                "use_cache": use_cache,
-                "attention_mask": attention_mask,
-                "pixel_values": pixel_values,
-                "pixel_values_videos": pixel_values_videos,
-                "image_grid_thw": image_grid_thw,
-                "video_grid_thw": video_grid_thw,
-                "rope_deltas": rope_deltas,
-            }
-        )
-        return model_inputs
-
+    # Copied from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1602
     def _update_model_kwargs_for_generation(
         self,
         outputs: ModelOutput,
```
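
The net effect of the change: instead of carrying ~170 lines copied from `transformers`, the wrapper now binds the upstream `Qwen2VLForConditionalGeneration.get_rope_index` and `prepare_inputs_for_generation` functions to itself with `types.MethodType`. The sketch below illustrates that binding pattern in isolation; `UpstreamModel`, `OVWrapper`, and their attributes are simplified stand-ins invented for the example, not the real transformers or optimum-intel classes.

```python
from types import MethodType


class UpstreamModel:
    def get_rope_index(self, seq_len):
        # Hypothetical upstream helper: relies only on attributes (`self.offset`)
        # that the wrapper below also provides.
        return list(range(self.offset, self.offset + seq_len))


class OVWrapper:
    def __init__(self, offset):
        self.offset = offset
        # Rebind the upstream function so that `self` inside it refers to this
        # wrapper instance rather than an UpstreamModel instance.
        self.get_rope_index = MethodType(UpstreamModel.get_rope_index, self)


wrapper = OVWrapper(offset=10)
print(wrapper.get_rope_index(3))  # -> [10, 11, 12]
```

`MethodType(func, obj)` simply fixes the function's first argument to `obj`, so the reused methods see the OpenVINO wrapper as `self`. The wrapper only has to expose the attributes the upstream code reads (in the deleted copies above, `self.config` and `self.get_rope_index`), which is what allows the 172 duplicated lines to be dropped.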
