from transformers import AutoModel, AutoTokenizer, AutoProcessor, TextIteratorStreamer
from transformers.generation import GenerationMixin
from transformers import AutoConfig, GenerationConfig
- from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPooling
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from pathlib import Path
from huggingface_hub import snapshot_download
import types
- from typing import Optional, Tuple, List
+ from typing import Optional, Tuple, List, Union
from openvino.runtime import opset13
import openvino as ov
import openvino_tokenizers
import numpy as np
import gc
+ from openvino.runtime.passes import Manager, MatcherPass, WrapType, Matcher
+ import time

text_emb_path = Path("embed_tokens.xml")
image_emb_path = Path("image_encoder.xml")
resampler_path = Path("resampler.xml")
llm_path = Path("language_model.xml")

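+ # Matcher pass that finds the logits output (the only Result node with a 3D
+ # [batch, seq_len, vocab] shape) and inserts a Slice upstream of it, so only the
+ # last sequence position flows through the final matmul instead of logits being
+ # computed for every prompt token.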
+ class InsertSlice(MatcherPass):
+     def __init__(self):
+         MatcherPass.__init__(self)
+         self.model_changed = False
+
+         param = WrapType("opset10.Result")
+
+         def callback(matcher: Matcher) -> bool:
+             root = matcher.get_match_root()
+             if root is None:
+                 return False
+             if len(root.get_output_partial_shape(0)) == 3:
+                 parent = root.input_value(0).get_node()
+                 grand_parent = parent.input_value(0).get_node()
+
+                 grand_parent_output = parent.input(0).get_source_output()
+                 consumers = grand_parent_output.get_target_inputs()
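+                 # start=[0, -1, 0], stop=[1, -2, hidden_size], step=[1, -1, 1]:
+                 # along axis 1 this selects exactly the last sequence position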
+                 start = np.array([0, -1, 0], dtype=np.int32)
+                 stop = np.array([1, -2, grand_parent_output.get_partial_shape()[-1].get_length()], dtype=np.int32)
+                 step = np.array([1, -1, 1], dtype=np.int32)
+                 axes = np.array([0, 1, 2], dtype=np.int32)
+                 slice = opset13.slice(grand_parent, start, stop, step, axes, name="inserted_slice")
+                 for consumer in consumers:
+                     consumer.replace_source_output(slice.output(0))
+                 self.model_changed = True
+                 # Use new operation for additional matching
+                 self.register_new_node(slice)
+                 print("applied slice for lm head")
+
+             return True
+
+         self.register_matcher(Matcher(param, "InsertSlice"), callback)
+

def model_has_state(ov_model: ov.Model):
    return len(ov_model.get_sinks()) > 0
@@ -324,13 +360,151 @@ def convert_vision_encoder(model, model_dir):
    tgt_sizes = torch.tensor([[23, 45]])
    if not (model_dir / image_emb_path).exists():
        print("⌛ Convert Image embedding model")
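+         # The forwards below are tracing-friendly copies of the SigLIP vision tower
+         # methods: position_ids may be passed in as an explicit tensor instead of
+         # being built in a data-dependent Python loop, and attention goes through
+         # scaled_dot_product_attention.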
+         def siglip_vis_embed_forward(
+             self,
+             pixel_values: torch.FloatTensor,
+             patch_attention_mask: torch.BoolTensor,
+             tgt_sizes: Optional[torch.IntTensor] = None,
+             position_ids: Optional[torch.FloatTensor] = None,
+         ) -> torch.Tensor:
+             patch_embeds = self.patch_embedding(pixel_values)
+             embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
+             if position_ids is None:
+                 batch_size = pixel_values.size(0)
+                 max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
+                 max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
+                 boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
+                 position_ids = torch.full(
+                     size=(
+                         batch_size,
+                         max_nb_patches_h * max_nb_patches_w,
+                     ),
+                     fill_value=0,
+                 )
+
+                 for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+                     if tgt_sizes is not None:
+                         nb_patches_h = tgt_sizes[batch_idx][0]
+                         nb_patches_w = tgt_sizes[batch_idx][1]
+                     else:
+                         nb_patches_h = p_attn_mask[:, 0].sum()
+                         nb_patches_w = p_attn_mask[0].sum()
+
+                     fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
+                     fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
+
+                     bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
+                     bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
+
+                     pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
+                     position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
+
+                 position_ids = position_ids.to(self.position_embedding.weight.device)
+
+             embeddings = embeddings + self.position_embedding(position_ids)
+             return embeddings
+
+         def siglip_attn_forward(
+             self,
+             hidden_states: torch.Tensor,
+             attention_mask: Optional[torch.Tensor] = None,
+             output_attentions: Optional[bool] = False,
+         ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+             """Input shape: Batch x Time x Channel"""
+
+             batch_size, q_len, _ = hidden_states.size()
+
+             query_states = self.q_proj(hidden_states)
+             key_states = self.k_proj(hidden_states)
+             value_states = self.v_proj(hidden_states)
+
+             query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+             key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+             value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+             attn_output = torch.nn.functional.scaled_dot_product_attention(
+                 query_states, key_states, value_states, attention_mask, is_causal=attention_mask is None
+             )
+
+             attn_output = attn_output.transpose(1, 2).contiguous()
+             attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
+
+             attn_output = self.out_proj(attn_output)
+
+             return attn_output, None
+
+         def siglip_transformer_forward(
+             self,
+             pixel_values,
+             patch_attention_mask: Optional[torch.BoolTensor] = None,
+             tgt_sizes: Optional[torch.IntTensor] = None,
+             position_ids: Optional[torch.FloatTensor] = None,
+             output_attentions: Optional[bool] = None,
+             output_hidden_states: Optional[bool] = None,
+             return_dict: Optional[bool] = None,
+         ) -> Union[Tuple, BaseModelOutputWithPooling]:
+             output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+             output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+             return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+             batch_size = pixel_values.size(0)
+             if patch_attention_mask is None:
+                 patch_attention_mask = torch.ones(
+                     size=(
+                         batch_size,
+                         pixel_values.size(2) // self.config.patch_size,
+                         pixel_values.size(3) // self.config.patch_size,
+                     ),
+                     dtype=torch.bool,
+                     device=pixel_values.device,
+                 )
+
+             hidden_states = self.embeddings(
+                 pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, tgt_sizes=tgt_sizes, position_ids=position_ids
+             )
+
+             patch_attention_mask = patch_attention_mask.view(batch_size, -1)
+             attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) if not self._use_flash_attention_2 else patch_attention_mask
+
+             encoder_outputs = self.encoder(
+                 inputs_embeds=hidden_states,
+                 attention_mask=attention_mask,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict,
+             )
+
+             last_hidden_state = encoder_outputs[0]
+             last_hidden_state = self.post_layernorm(last_hidden_state)
+
+             if not return_dict:
+                 return (last_hidden_state, None) + encoder_outputs[1:]
+
+             return BaseModelOutputWithPooling(
+                 last_hidden_state=last_hidden_state,
+                 pooler_output=None,
+                 hidden_states=encoder_outputs.hidden_states,
+                 attentions=encoder_outputs.attentions,
+             )
+
+         vpm = model.vpm
+         vpm.embeddings.forward = types.MethodType(siglip_vis_embed_forward, vpm.embeddings)
+         for layer in vpm.encoder.layers:
+             layer.self_attn.forward = types.MethodType(siglip_attn_forward, layer.self_attn)
+         vpm.forward = types.MethodType(siglip_transformer_forward, vpm)
+
        pixel_values = torch.randn([1, 3, 14, 14490])
        patch_attn_mask = torch.zeros((1, 1, 1035), dtype=torch.bool)
        patch_attn_mask[0, 0, : tgt_sizes[0][0] * tgt_sizes[0][1]] = True
-         ov_model = ov.convert_model(model.vpm, example_input={"pixel_values": pixel_values, "tgt_sizes": tgt_sizes, "patch_attention_mask": patch_attn_mask})
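+         # position_ids becomes an explicit encoder input, precomputed on the host
+         # with the same bucketization the embedding layer used internally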
+         position_ids = prepare_vis_position_ids(
+             pixel_values, patch_attn_mask, tgt_sizes, model.config.vision_config.patch_size, model.config.vision_config.image_size // model.config.patch_size
+         )
+         ov_model = ov.convert_model(vpm, example_input={"pixel_values": pixel_values, "position_ids": position_ids, "patch_attention_mask": patch_attn_mask})
        ov.save_model(ov_model, model_dir / image_emb_path)
        del ov_model
        cleanup_torchscript_cache()
+         gc.collect()
        print("✅ Image embedding model successfully converted")
@@ -343,7 +517,9 @@ def resampler_forward(self, x, pos_embed, key_padding_mask):

        q = self.ln_q(self.query)  # Q * D

-         out = self.attn(self._repeat(q, bs), x + pos_embed, x, key_padding_mask=key_padding_mask)[0]  # Q * B * D  # L * B * D + L * B * D
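+         # Inline replacement for self._repeat(q, bs): (Q, D) -> (Q, 1, D) -> (Q, bs, D)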
+         q_bs = q.unsqueeze(1).repeat(1, bs, 1)
+
+         out = self.attn(q_bs, x + pos_embed, x, key_padding_mask=key_padding_mask)[0]  # Q * B * D  # L * B * D + L * B * D
        # out: Q * B * D
        x = out.permute(1, 0, 2)  # B * Q * D
@@ -369,6 +545,8 @@ def resampler_forward(self, x, pos_embed, key_padding_mask):
        ov.save_model(ov_model, model_dir / resampler_path)
        del ov_model
        cleanup_torchscript_cache()
+         del model.resampler
+         gc.collect()
        print("✅ Resampler model successfully converted")
@@ -380,11 +558,38 @@ def copy_llm_files(model_dir, dst_dir):
    shutil.copy(model_dir / llm_path.parent / "modeling_navit_siglip.py", model_dir / dst_dir / "modeling_navit_siglip.py")


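+ # Host-side copy of the bucketized position computation from siglip_vis_embed_forward;
+ # its result is fed to the converted image encoder as the position_ids input.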
+ def prepare_vis_position_ids(pixel_values, patch_attention_mask, tgt_sizes, patch_size, num_patches_per_side):
+     batch_size = pixel_values.size(0)
+     max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
+     max_nb_patches_h, max_nb_patches_w = max_im_h // patch_size, max_im_w // patch_size
+     boundaries = torch.arange(1 / num_patches_per_side, 1.0, 1 / num_patches_per_side)
+     position_ids = torch.full(size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0)
+
+     for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+         if tgt_sizes is not None:
+             nb_patches_h = tgt_sizes[batch_idx][0]
+             nb_patches_w = tgt_sizes[batch_idx][1]
+         else:
+             nb_patches_h = p_attn_mask[:, 0].sum()
+             nb_patches_w = p_attn_mask[0].sum()
+
+         fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
+         fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
+
+         bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
+         bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
+
+         pos_ids = (bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w).flatten()
+         position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
+
+     return position_ids
+
+
core = ov.Core()


class OvModelForCausalLMWithEmb(GenerationMixin):
-     def __init__(self, model_dir, device="CPU", ov_config=None, compile=True) -> None:
+     def __init__(self, model_dir, device="CPU", ov_config=None, compile=True, slice_lm_head=True) -> None:
        self._supports_cache_class = False
        self.config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
        self.config.is_decoder = True
@@ -393,6 +598,8 @@ def __init__(self, model_dir, device="CPU", ov_config=None, compile=True) -> Non
        model_dir = Path(model_dir)
        self.model = core.read_model(model_dir / "language_model.xml")
        self.token_emb = core.read_model(model_dir / "embed_tokens.xml")
+         if slice_lm_head:
+             self.slice_lm_head()
        self.request = None
        self.token_emb_request = None
        self._device = device.upper()
@@ -402,9 +609,16 @@ def __init__(self, model_dir, device="CPU", ov_config=None, compile=True) -> Non
        self._past_length = None
        self.input_names = [input_t.get_any_name() for input_t in self.model.inputs]
        self.main_input_name = "input_ids"
+         self.llm_times = []
        if compile:
            self.compile()

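+     # Rewrites the language model graph with InsertSlice so the compiled model
+     # returns logits for the last token only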
+     def slice_lm_head(self):
+         manager = Manager()
+         manager.register_pass(InsertSlice())
+         manager.run_passes(self.model)
+         self.model.validate_nodes_and_infer_types()
+
    def compile(self):
        if self.request is None:
            self.request = core.compile_model(self.model, self._device, self.ov_config).create_infer_request()
@@ -446,6 +660,7 @@ def prepare_inputs(
        inputs = {}
        # past_key_values are not used explicitly, instead they are handled inside the model
        if past_key_values is None:
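+             # A new sequence also starts with a fresh LLM timing log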
+             self.llm_times = []
            # This is the first iteration in a sequence, reset all states
            if self.request is not None:
                self.request.reset_state()
@@ -657,20 +872,39 @@ def get_vllm_embedding(self, data):
            for i in range(B):
                patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True

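+             # A larger vision batch means fewer image-encoder inference calls; each
+             # block gets host-side position ids, and per-call latency is recorded
+             # in self.vpm_times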
-             vision_batch_size = 1
+             vision_batch_size = 32
            all_pixel_values = all_pixel_values
            if B > vision_batch_size:
                hs = []
                for i in range(0, B, vision_batch_size):
                    start_idx = i
                    end_idx = i + vision_batch_size
-                     tmp_hs = torch.from_numpy(
-                         self.vpm([all_pixel_values[start_idx:end_idx], patch_attn_mask[start_idx:end_idx], tgt_sizes[start_idx:end_idx]])[0]
+                     block_pxl_values = all_pixel_values[start_idx:end_idx]
+                     block_patch_attn_mask = patch_attn_mask[start_idx:end_idx]
+                     block_tgt_sizes = tgt_sizes[start_idx:end_idx]
+                     block_position_ids = prepare_vis_position_ids(
+                         block_pxl_values,
+                         block_patch_attn_mask,
+                         block_tgt_sizes,
+                         self.config.vision_config.patch_size,
+                         self.config.vision_config.image_size // self.config.patch_size,
                    )
+                     start = time.perf_counter()
+                     tmp_hs = torch.from_numpy(self.vpm([block_pxl_values, block_patch_attn_mask, block_position_ids])[0])
+                     self.vpm_times.append(time.perf_counter() - start)
                    hs.append(tmp_hs)
                vision_embedding = torch.cat(hs, dim=0)
            else:
-                 vision_embedding = torch.from_numpy(self.vpm([all_pixel_values, patch_attn_mask, tgt_sizes])[0])
+                 position_ids = prepare_vis_position_ids(
+                     all_pixel_values,
+                     patch_attn_mask,
+                     tgt_sizes,
+                     self.config.vision_config.patch_size,
+                     self.config.vision_config.image_size // self.config.patch_size,
+                 )
+                 start = time.perf_counter()
+                 vision_embedding = torch.from_numpy(self.vpm([all_pixel_values, patch_attn_mask, position_ids])[0])
+                 self.vpm_times.append(time.perf_counter() - start)
            vision_embedding = self.resampler(vision_embedding, tgt_sizes)

            start = 0
@@ -801,6 +1035,8 @@ def chat(
        use_image_id=None,
        **kwargs,
    ):
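+         # Reset per-call timing stats for the vision encoder and resampler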
+         self.vpm_times = []
+         self.resampler_times = []
        if isinstance(msgs[0], list):
            batched = True
        else:
@@ -844,7 +1080,6 @@ def chat(
        copy_msgs = deepcopy(msgs)

        assert len(msgs) > 0, "msgs is empty"
-         assert sampling or not stream, "if use stream mode, make sure sampling=True"

        if image is not None and isinstance(copy_msgs[0]["content"], str):
            copy_msgs[0]["content"] = [image, copy_msgs[0]["content"]]
@@ -882,7 +1117,6 @@ def chat(
            generation_config = {"top_p": 0.8, "top_k": 100, "temperature": 0.7, "do_sample": True, "repetition_penalty": 1.05}
        else:
            generation_config = {
-                 "num_beams": 3,
                "repetition_penalty": 1.2,
            }