@@ -574,6 +574,8 @@ def forward(
         image_sizes=None,
         attention_mask=None,
         position_ids=None,
+        image_bound=None,
+        tgt_sizes=None,
         **kwargs,
     ):
         inputs_embeds, attention_mask, position_ids = self.get_multimodal_embeddings(
@@ -583,6 +585,8 @@ def forward(
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            image_bound=image_bound,
+            tgt_sizes=tgt_sizes,
             **kwargs,
         )
         return self.language_model.forward(
@@ -625,6 +629,7 @@ def get_multimodal_embeddings(
         )
         return inputs_embeds, attention_mask, position_ids

+    # Adopted from https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/llava/modeling_llava.py#L521
     def prepare_inputs_for_generation(
         self,
         input_ids,
@@ -649,7 +654,7 @@ def prepare_inputs_for_generation(
             elif past_length < input_ids.shape[1]:
                 input_ids = input_ids[:, past_length:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-            elif self.config.image_token_index in input_ids:
+            elif getattr(self.config, "image_token_index", -1) in input_ids:
                 input_ids = input_ids[:, input_ids.shape[1] - 1 :]

         position_ids = kwargs.get("position_ids", None)
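Note: replacing the direct attribute access with `getattr(self.config, "image_token_index", -1)` makes this shared `prepare_inputs_for_generation` safe for configs that define no `image_token_index` (as with MiniCPM-V); since token ids are non-negative, the `-1` fallback can never match, so the branch is simply skipped instead of raising an `AttributeError`.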
@@ -673,6 +678,8 @@ def prepare_inputs_for_generation(
                 "attention_mask": attention_mask,
                 "pixel_values": pixel_values,
                 "image_sizes": image_sizes,
+                "image_bound": kwargs.get("image_bound"),
+                "tgt_sizes": kwargs.get("tgt_sizes"),
             }
         )
         return model_inputs
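For context: `generate()` forwards extra keyword arguments into `prepare_inputs_for_generation` as model kwargs, which is why reading them back with `kwargs.get("image_bound")` / `kwargs.get("tgt_sizes")` above is sufficient. A minimal sketch of a call site, assuming a MiniCPM-V processor whose output dict carries these keys (the variable names and keys are illustrative, not guaranteed by this diff):

# Illustrative sketch only; `ov_model`, `inputs`, and the processor's output keys are assumptions.
outputs = ov_model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    pixel_values=inputs["pixel_values"],
    image_bound=inputs["image_bound"],  # spans of the image placeholder tokens in the prompt
    tgt_sizes=inputs["tgt_sizes"],      # per-slice patch-grid sizes for the vision resampler
    max_new_tokens=64,
)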
@@ -1362,83 +1369,6 @@ def merge_vision_text_embeddings(
         )
         return vllm_embedding, attention_mask, position_ids

-    def prepare_inputs_for_generation(
-        self,
-        input_ids,
-        past_key_values=None,
-        inputs_embeds=None,
-        pixel_values=None,
-        image_sizes=None,
-        attention_mask=None,
-        **kwargs,
-    ):
-        if past_key_values is not None:
-            past_length = self.language_model._get_past_length(past_key_values)
-
-            # Keep only the unprocessed tokens:
-            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-            # input)
-            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-            # input_ids based on the past_length.
-            elif past_length < input_ids.shape[1]:
-                input_ids = input_ids[:, past_length:]
-
-        position_ids = kwargs.get("position_ids", None)
-        if attention_mask is not None and position_ids is None:
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids}
-
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask,
-                "pixel_values": pixel_values,
-                "image_sizes": image_sizes,
-                "image_bound": kwargs.get("image_bound"),
-                "tgt_sizes": kwargs.get("tgt_sizes"),
-            }
-        )
-        return model_inputs
-
-    def forward(
-        self,
-        input_ids,
-        pixel_values,
-        past_key_values=None,
-        inputs_embeds=None,
-        image_sizes=None,
-        attention_mask=None,
-        position_ids=None,
-        image_bound=None,
-        tgt_sizes=None,
-        **kwargs,
-    ):
-        return super().forward(
-            input_ids,
-            pixel_values,
-            past_key_values,
-            inputs_embeds,
-            image_sizes,
-            attention_mask,
-            position_ids,
-            image_bound=image_bound,
-            tgt_sizes=tgt_sizes,
-            **kwargs,
-        )
-

 MODEL_TYPE_TO_CLS_MAPPING = {
     "llava": _OVLlavaForCausalLM,