 from dataclasses import dataclass
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+from types import MethodType

 import numpy as np
 import openvino as ov
@@ -2099,187 +2100,24 @@ def __init__(
             **kwargs,
         )
         if is_transformers_version(">=", "4.45.0"):
-            from transformers.models.qwen2_vl.modeling_qwen2_vl import VisionRotaryEmbedding
+            from transformers.models.qwen2_vl.modeling_qwen2_vl import (
+                VisionRotaryEmbedding,
+                Qwen2VLForConditionalGeneration,
+            )

             self._rotary_pos_emb = VisionRotaryEmbedding(
                 self.config.vision_config.embed_dim // self.config.vision_config.num_heads // 2
             )
+            self.get_rope_index = MethodType(Qwen2VLForConditionalGeneration.get_rope_index, self)
+            self.prepare_inputs_for_generation = MethodType(
+                Qwen2VLForConditionalGeneration.prepare_inputs_for_generation, self
+            )
         else:
             raise ValueError(
                 f"Initialization model for {self.config.model_type} required at least transformers >= 4.45"
             )

-    def get_rope_index(
-        self,
-        input_ids: torch.LongTensor,
-        image_grid_thw: Optional[torch.LongTensor] = None,
-        video_grid_thw: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
-        """
-        spatial_merge_size = self.config.vision_config.spatial_merge_size
-        image_token_id = self.config.image_token_id
-        video_token_id = self.config.video_token_id
-        vision_start_token_id = self.config.vision_start_token_id
-        mrope_position_deltas = []
-        if image_grid_thw is not None or video_grid_thw is not None:
-            total_input_ids = input_ids
-            position_ids = torch.ones(
-                3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
-            )
-            image_index, video_index = 0, 0
-            for i, input_ids in enumerate(total_input_ids):
-                if attention_mask is not None:
-                    input_ids = input_ids[attention_mask[i] == 1]
-                vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
-                vision_tokens = input_ids[vision_start_indices + 1]
-                image_nums = (vision_tokens == image_token_id).sum()
-                video_nums = (vision_tokens == video_token_id).sum()
-                input_tokens = input_ids.tolist()
-                llm_pos_ids_list: list = []
-                st = 0
-                remain_images, remain_videos = image_nums, video_nums
-                for _ in range(image_nums + video_nums):
-                    if image_token_id in input_tokens and remain_images > 0:
-                        ed_image = input_tokens.index(image_token_id, st)
-                    else:
-                        ed_image = len(input_tokens) + 1
-                    if video_token_id in input_tokens and remain_videos > 0:
-                        ed_video = input_tokens.index(video_token_id, st)
-                    else:
-                        ed_video = len(input_tokens) + 1
-                    if ed_image < ed_video:
-                        t, h, w = (
-                            image_grid_thw[image_index][0],
-                            image_grid_thw[image_index][1],
-                            image_grid_thw[image_index][2],
-                        )
-                        image_index += 1
-                        remain_images -= 1
-                        ed = ed_image
-                    else:
-                        t, h, w = (
-                            video_grid_thw[video_index][0],
-                            video_grid_thw[video_index][1],
-                            video_grid_thw[video_index][2],
-                        )
-                        video_index += 1
-                        remain_videos -= 1
-                        ed = ed_video
-                    llm_grid_t, llm_grid_h, llm_grid_w = (
-                        t.item(),
-                        h.item() // spatial_merge_size,
-                        w.item() // spatial_merge_size,
-                    )
-                    text_len = ed - st
-
-                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
-
-                    t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
-                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
-                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
-                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
-                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
-
-                if st < len(input_tokens):
-                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-                    text_len = len(input_tokens) - st
-                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
-
-                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
-                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
-                mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
-            mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
-            return position_ids, mrope_position_deltas
-        else:
-            if attention_mask is not None:
-                position_ids = attention_mask.long().cumsum(-1) - 1
-                position_ids.masked_fill_(attention_mask == 0, 1)
-                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
-                max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
-                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
-            else:
-                position_ids = (
-                    torch.arange(input_ids.shape[1], device=input_ids.device)
-                    .view(1, 1, -1)
-                    .expand(3, input_ids.shape[0], -1)
-                )
-                mrope_position_deltas = torch.zeros(
-                    [input_ids.shape[0], 1],
-                    device=input_ids.device,
-                    dtype=input_ids.dtype,
-                )
-
-            return position_ids, mrope_position_deltas
-
-    def prepare_inputs_for_generation(
-        self,
-        input_ids,
-        past_key_values=None,
-        attention_mask=None,
-        inputs_embeds=None,
-        cache_position=None,
-        position_ids=None,
-        use_cache=True,
-        pixel_values=None,
-        pixel_values_videos=None,
-        image_grid_thw=None,
-        video_grid_thw=None,
-        **kwargs,
-    ):
-        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
-        # Exception 1: when passing input_embeds, input_ids may be missing entries
-        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
-        if past_key_values is not None:
-            if inputs_embeds is not None:  # Exception 1
-                input_ids = input_ids[:, -cache_position.shape[0] :]
-            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
-                input_ids = input_ids[:, cache_position]
-
-        rope_deltas = kwargs.get("rope_deltas", None)
-        if attention_mask is not None and position_ids is None:
-            if cache_position is None or (cache_position is not None and cache_position[0] == 0):
-                position_ids, rope_deltas = self.get_rope_index(
-                    input_ids, image_grid_thw, video_grid_thw, attention_mask
-                )
-            else:
-                batch_size, seq_length = input_ids.shape
-                delta = (
-                    cache_position[0] + rope_deltas if cache_position is not None and rope_deltas is not None else 0
-                )
-                position_ids = torch.arange(seq_length, device=input_ids.device)
-                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
-                position_ids = position_ids.add(delta)
-                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
-
-        if cache_position[0] != 0:
-            pixel_values = None
-            pixel_values_videos = None
-
-        # if inputs_embeds are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids}
-
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "past_key_values": past_key_values,
-                "use_cache": use_cache,
-                "attention_mask": attention_mask,
-                "pixel_values": pixel_values,
-                "pixel_values_videos": pixel_values_videos,
-                "image_grid_thw": image_grid_thw,
-                "video_grid_thw": video_grid_thw,
-                "rope_deltas": rope_deltas,
-            }
-        )
-        return model_inputs
-
+    # Copied from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1602
     def _update_model_kwargs_for_generation(
         self,
         outputs: ModelOutput,