Commit c166e5f: Update the model
1 parent 12a7134 commit c166e5f

File tree: 5 files changed (+324 -24 lines)

.github/workflows/causal_lm_cpp.yml (+1 -1)

@@ -716,7 +716,7 @@ jobs:
           source ./ov/setupvars.sh
           && ./build/samples/cpp/visual_language_chat/visual_language_chat ./miniCPM-V-2_6/ d5fbbd1a-d484-415c-88cb-9986625b7b11
           <<< $'What is on the image?\nWhat is special on the image?'
-    timeout-minutes: 32
+    timeout-minutes: 300

   cpp-continuous-batching-ubuntu:
     runs-on: ubuntu-20.04-8-cores

samples/cpp/visual_language_chat/export_MiniCPM-V-2_6.py (+245 -11)
@@ -9,22 +9,58 @@
 from transformers import AutoModel, AutoTokenizer, AutoProcessor, TextIteratorStreamer
 from transformers.generation import GenerationMixin
 from transformers import AutoConfig, GenerationConfig
-from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPooling
+from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
 from pathlib import Path
 from huggingface_hub import snapshot_download
 import types
-from typing import Optional, Tuple, List
+from typing import Optional, Tuple, List, Union
 from openvino.runtime import opset13
 import openvino as ov
 import openvino_tokenizers
 import numpy as np
 import gc
+from openvino.runtime.passes import Manager, MatcherPass, WrapType, Matcher
+import time

 text_emb_path = Path("embed_tokens.xml")
 image_emb_path = Path("image_encoder.xml")
 resampler_path = Path("resampler.xml")
 llm_path = Path("language_model.xml")

+class InsertSlice(MatcherPass):
+    def __init__(self):
+        MatcherPass.__init__(self)
+        self.model_changed = False
+
+        param = WrapType("opset10.Result")
+
+        def callback(matcher: Matcher) -> bool:
+            root = matcher.get_match_root()
+            if root is None:
+                return False
+            if len(root.get_output_partial_shape(0)) == 3:
+                parent = root.input_value(0).get_node()
+                grand_parent = parent.input_value(0).get_node()
+
+                grand_parent_output = parent.input(0).get_source_output()
+                consumers = grand_parent_output.get_target_inputs()
+                start = np.array([0, -1, 0], dtype=np.int32)
+                stop = np.array([1, -2, grand_parent_output.get_partial_shape()[-1].get_length()], dtype=np.int32)
+                step = np.array([1, -1, 1], dtype=np.int32)
+                axes = np.array([0, 1, 2], dtype=np.int32)
+                slice = opset13.slice(grand_parent, start, stop, step, axes, name="inserted_slice")
+                for consumer in consumers:
+                    consumer.replace_source_output(slice.output(0))
+                self.model_changed = True
+                # Use new operation for additional matching
+                self.register_new_node(slice)
+                print("applied slice for lm head")
+
+            return True
+
+        self.register_matcher(Matcher(param, "InsertSlice"), callback)
+

 def model_has_state(ov_model: ov.Model):
     return len(ov_model.get_sinks()) > 0
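
Note (not part of the diff): the InsertSlice pass above trims the exported language model so that only the last token's hidden state reaches the LM head. A minimal usage sketch, assuming the class above is in scope and a converted language_model.xml is on disk; it mirrors the slice_lm_head helper added later in this commit:

    import openvino as ov
    from openvino.runtime.passes import Manager

    core = ov.Core()
    lm = core.read_model("language_model.xml")  # path assumed from this sample's layout

    manager = Manager()
    manager.register_pass(InsertSlice())  # the MatcherPass defined in the hunk above
    manager.run_passes(lm)
    lm.validate_nodes_and_infer_types()  # re-infer shapes after the graph rewrite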
@@ -324,13 +360,151 @@ def convert_vision_encoder(model, model_dir):
     tgt_sizes = torch.tensor([[23, 45]])
     if not (model_dir / image_emb_path).exists():
         print("⌛ Convert Image embedding model")
+        def siglip_vis_embed_forward(
+            self,
+            pixel_values: torch.FloatTensor,
+            patch_attention_mask: torch.BoolTensor,
+            tgt_sizes: Optional[torch.IntTensor] = None,
+            position_ids: Optional[torch.FloatTensor] = None,
+        ) -> torch.Tensor:
+            patch_embeds = self.patch_embedding(pixel_values)
+            embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
+            if position_ids is None:
+                batch_size = pixel_values.size(0)
+                max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
+                max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
+                boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
+                position_ids = torch.full(
+                    size=(
+                        batch_size,
+                        max_nb_patches_h * max_nb_patches_w,
+                    ),
+                    fill_value=0,
+                )
+
+                for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+                    if tgt_sizes is not None:
+                        nb_patches_h = tgt_sizes[batch_idx][0]
+                        nb_patches_w = tgt_sizes[batch_idx][1]
+                    else:
+                        nb_patches_h = p_attn_mask[:, 0].sum()
+                        nb_patches_w = p_attn_mask[0].sum()
+
+                    fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
+                    fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
+
+                    bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
+                    bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
+
+                    pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
+                    position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
+
+                position_ids = position_ids.to(self.position_embedding.weight.device)
+
+            embeddings = embeddings + self.position_embedding(position_ids)
+            return embeddings
+
+        def siglip_attn_forward(
+            self,
+            hidden_states: torch.Tensor,
+            attention_mask: Optional[torch.Tensor] = None,
+            output_attentions: Optional[bool] = False,
+        ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+            """Input shape: Batch x Time x Channel"""
+
+            batch_size, q_len, _ = hidden_states.size()
+
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+
+            query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states, key_states, value_states, attention_mask, is_causal=attention_mask is None
+            )
+
+            attn_output = attn_output.transpose(1, 2).contiguous()
+            attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
+
+            attn_output = self.out_proj(attn_output)
+
+            return attn_output, None
+
+        def siglip_transformer_forward(
+            self,
+            pixel_values,
+            patch_attention_mask: Optional[torch.BoolTensor] = None,
+            tgt_sizes: Optional[torch.IntTensor] = None,
+            position_ids: Optional[torch.FloatTensor] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+        ) -> Union[Tuple, BaseModelOutputWithPooling]:
+            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+            output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+            batch_size = pixel_values.size(0)
+            if patch_attention_mask is None:
+                patch_attention_mask = torch.ones(
+                    size=(
+                        batch_size,
+                        pixel_values.size(2) // self.config.patch_size,
+                        pixel_values.size(3) // self.config.patch_size,
+                    ),
+                    dtype=torch.bool,
+                    device=pixel_values.device,
+                )
+
+            hidden_states = self.embeddings(
+                pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, tgt_sizes=tgt_sizes, position_ids=position_ids
+            )
+
+            patch_attention_mask = patch_attention_mask.view(batch_size, -1)
+            attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) if not self._use_flash_attention_2 else patch_attention_mask
+
+            encoder_outputs = self.encoder(
+                inputs_embeds=hidden_states,
+                attention_mask=attention_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+
+            last_hidden_state = encoder_outputs[0]
+            last_hidden_state = self.post_layernorm(last_hidden_state)
+
+            if not return_dict:
+                return (last_hidden_state, None) + encoder_outputs[1:]
+
+            return BaseModelOutputWithPooling(
+                last_hidden_state=last_hidden_state,
+                pooler_output=None,
+                hidden_states=encoder_outputs.hidden_states,
+                attentions=encoder_outputs.attentions,
+            )
+
+        vpm = model.vpm
+        vpm.embeddings.forward = types.MethodType(siglip_vis_embed_forward, vpm.embeddings)
+        for layer in vpm.encoder.layers:
+            layer.self_attn.forward = types.MethodType(siglip_attn_forward, layer.self_attn)
+        vpm.forward = types.MethodType(siglip_transformer_forward, vpm)
+
         pixel_values = torch.randn([1, 3, 14, 14490])
         patch_attn_mask = torch.zeros((1, 1, 1035), dtype=torch.bool)
         patch_attn_mask[0, 0, : tgt_sizes[0][0] * tgt_sizes[0][1]] = True
-        ov_model = ov.convert_model(model.vpm, example_input={"pixel_values": pixel_values, "tgt_sizes": tgt_sizes, "patch_attention_mask": patch_attn_mask})
+        position_ids = prepare_vis_position_ids(
+            pixel_values, patch_attn_mask, tgt_sizes, model.config.vision_config.patch_size, model.config.vision_config.image_size // model.config.patch_size
+        )
+        ov_model = ov.convert_model(vpm, example_input={"pixel_values": pixel_values, "position_ids": position_ids, "patch_attention_mask": patch_attn_mask})
         ov.save_model(ov_model, model_dir / image_emb_path)
         del ov_model
         cleanup_torchscript_cache()
+        gc.collect()
         print("✅ Image embedding model successfully converted")

     if not (model_dir / resampler_path).exists():
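
Note (not part of the diff): the hunk above swaps the SigLIP embedding, attention, and transformer forwards by re-binding plain functions as instance methods via types.MethodType. A tiny illustration of that binding pattern, using a made-up Encoder class rather than the real model:

    import types

    class Encoder:
        def forward(self, x):
            return x

    def patched_forward(self, x):
        # `self` is the instance the function is bound to
        return x * 2

    enc = Encoder()
    enc.forward = types.MethodType(patched_forward, enc)  # affects this instance only
    print(enc.forward(3))  # 6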
@@ -343,7 +517,9 @@ def resampler_forward(self, x, pos_embed, key_padding_mask):

             q = self.ln_q(self.query)  # Q * D

-            out = self.attn(self._repeat(q, bs), x + pos_embed, x, key_padding_mask=key_padding_mask)[0]  # Q * B * D # L * B * D + L * B * D
+            q_bs = q.unsqueeze(1).repeat(1, bs, 1)
+
+            out = self.attn(q_bs, x + pos_embed, x, key_padding_mask=key_padding_mask)[0]  # Q * B * D # L * B * D + L * B * D
             # out: Q * B * D
             x = out.permute(1, 0, 2)  # B * Q * D

@@ -369,6 +545,8 @@ def resampler_forward(self, x, pos_embed, key_padding_mask):
         ov.save_model(ov_model, model_dir / resampler_path)
         del ov_model
         cleanup_torchscript_cache()
+        del model.resampler
+        gc.collect()
         print("✅ Resampler model successfully converted")


@@ -380,11 +558,38 @@ def copy_llm_files(model_dir, dst_dir):
     shutil.copy(model_dir / llm_path.parent / "modeling_navit_siglip.py", model_dir / dst_dir / "modeling_navit_siglip.py")


+def prepare_vis_position_ids(pixel_values, patch_attention_mask, tgt_sizes, patch_size, num_patches_per_side):
+    batch_size = pixel_values.size(0)
+    max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
+    max_nb_patches_h, max_nb_patches_w = max_im_h // patch_size, max_im_w // patch_size
+    boundaries = torch.arange(1 / num_patches_per_side, 1.0, 1 / num_patches_per_side)
+    position_ids = torch.full(size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0)
+
+    for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+        if tgt_sizes is not None:
+            nb_patches_h = tgt_sizes[batch_idx][0]
+            nb_patches_w = tgt_sizes[batch_idx][1]
+        else:
+            nb_patches_h = p_attn_mask[:, 0].sum()
+            nb_patches_w = p_attn_mask[0].sum()
+
+        fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
+        fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
+
+        bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
+        bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
+
+        pos_ids = (bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w).flatten()
+        position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
+
+    return position_ids
+
+
 core = ov.Core()


 class OvModelForCausalLMWithEmb(GenerationMixin):
-    def __init__(self, model_dir, device="CPU", ov_config=None, compile=True) -> None:
+    def __init__(self, model_dir, device="CPU", ov_config=None, compile=True, slice_lm_head=True) -> None:
         self._supports_cache_class = False
         self.config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
         self.config.is_decoder = True
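
Note (not part of the diff): prepare_vis_position_ids can be exercised with the same dummy tensors convert_vision_encoder uses above; the patch size of 14 and the 70 patches per side (image_size // patch_size) are assumed example values here, not values read from the model config:

    import torch

    pixel_values = torch.randn([1, 3, 14, 14490])
    tgt_sizes = torch.tensor([[23, 45]])
    patch_attn_mask = torch.zeros((1, 1, 1035), dtype=torch.bool)
    patch_attn_mask[0, 0, : tgt_sizes[0][0] * tgt_sizes[0][1]] = True

    position_ids = prepare_vis_position_ids(pixel_values, patch_attn_mask, tgt_sizes, 14, 70)
    print(position_ids.shape)  # torch.Size([1, 1035])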
@@ -393,6 +598,8 @@ def __init__(self, model_dir, device="CPU", ov_config=None, compile=True) -> None:
         model_dir = Path(model_dir)
         self.model = core.read_model(model_dir / "language_model.xml")
         self.token_emb = core.read_model(model_dir / "embed_tokens.xml")
+        if slice_lm_head:
+            self.slice_lm_head()
         self.request = None
         self.token_emb_request = None
         self._device = device.upper()
@@ -402,9 +609,16 @@ def __init__(self, model_dir, device="CPU", ov_config=None, compile=True) -> None:
         self._past_length = None
         self.input_names = [input_t.get_any_name() for input_t in self.model.inputs]
         self.main_input_name = "input_ids"
+        self.llm_times = []
         if compile:
             self.compile()

+    def slice_lm_head(self):
+        manager = Manager()
+        manager.register_pass(InsertSlice())
+        manager.run_passes(self.model)
+        self.model.validate_nodes_and_infer_types()
+
     def compile(self):
         if self.request is None:
             self.request = core.compile_model(self.model, self._device, self.ov_config).create_infer_request()
@@ -446,6 +660,7 @@ def prepare_inputs(
         inputs = {}
         # past_key_values are not used explicitly, instead they are handled inside the model
         if past_key_values is None:
+            self.llm_times = []
             # This is the first iteration in a sequence, reset all states
             if self.request is not None:
                 self.request.reset_state()
@@ -657,20 +872,39 @@ def get_vllm_embedding(self, data):
             for i in range(B):
                 patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True

-            vision_batch_size = 1
+            vision_batch_size = 32
             all_pixel_values = all_pixel_values
             if B > vision_batch_size:
                 hs = []
                 for i in range(0, B, vision_batch_size):
                     start_idx = i
                     end_idx = i + vision_batch_size
-                    tmp_hs = torch.from_numpy(
-                        self.vpm([all_pixel_values[start_idx:end_idx], patch_attn_mask[start_idx:end_idx], tgt_sizes[start_idx:end_idx]])[0]
+                    block_pxl_values = all_pixel_values[start_idx:end_idx]
+                    block_patch_attn_mask = patch_attn_mask[start_idx:end_idx]
+                    block_tgt_sizes = tgt_sizes[start_idx:end_idx]
+                    block_position_ids = prepare_vis_position_ids(
+                        block_pxl_values,
+                        block_patch_attn_mask,
+                        block_tgt_sizes,
+                        self.config.vision_config.patch_size,
+                        self.config.vision_config.image_size // self.config.patch_size,
                     )
+                    start = time.perf_counter()
+                    tmp_hs = torch.from_numpy(self.vpm([block_pxl_values, block_patch_attn_mask, block_position_ids])[0])
+                    self.vpm_times.append(time.perf_counter() - start)
                     hs.append(tmp_hs)
                 vision_embedding = torch.cat(hs, dim=0)
             else:
-                vision_embedding = torch.from_numpy(self.vpm([all_pixel_values, patch_attn_mask, tgt_sizes])[0])
+                position_ids = prepare_vis_position_ids(
+                    all_pixel_values,
+                    patch_attn_mask,
+                    tgt_sizes,
+                    self.config.vision_config.patch_size,
+                    self.config.vision_config.image_size // self.config.patch_size,
+                )
+                start = time.perf_counter()
+                vision_embedding = torch.from_numpy(self.vpm([all_pixel_values, patch_attn_mask, position_ids])[0])
+                self.vpm_times.append(time.perf_counter() - start)
             vision_embedding = self.resampler(vision_embedding, tgt_sizes)

             start = 0
@@ -801,6 +1035,8 @@ def chat(
         use_image_id=None,
         **kwargs,
     ):
+        self.vpm_times = []
+        self.resampler_times = []
         if isinstance(msgs[0], list):
             batched = True
         else:
@@ -844,7 +1080,6 @@ def chat(
         copy_msgs = deepcopy(msgs)

         assert len(msgs) > 0, "msgs is empty"
-        assert sampling or not stream, "if use stream mode, make sure sampling=True"

         if image is not None and isinstance(copy_msgs[0]["content"], str):
             copy_msgs[0]["content"] = [image, copy_msgs[0]["content"]]
@@ -882,7 +1117,6 @@ def chat(
             generation_config = {"top_p": 0.8, "top_k": 100, "temperature": 0.7, "do_sample": True, "repetition_penalty": 1.05}
         else:
             generation_config = {
-                "num_beams": 3,
                 "repetition_penalty": 1.2,
             }

src/cpp/include/openvino/genai/processor_config.hpp (+1)

@@ -14,6 +14,7 @@ namespace ov::genai {
 /// preprocessor_config.json.
 class OPENVINO_GENAI_EXPORTS ProcessorConfig {
 public:
+    size_t image_size = 980;
     /// @brief Dimensions of the smaller, non-overlapping patches that the
     /// input image is divided into before being fed into the
     /// transformer model. Used to divide image height and width.
