Commit 483a55d

refactoring
1 parent bddbdd2 commit 483a55d

1 file changed: +8 -78 lines changed


optimum/intel/openvino/modeling_visual_language.py (+8 -78)
@@ -574,6 +574,8 @@ def forward(
         image_sizes=None,
         attention_mask=None,
         position_ids=None,
+        image_bound=None,
+        tgt_sizes=None,
         **kwargs,
     ):
         inputs_embeds, attention_mask, position_ids = self.get_multimodal_embeddings(
@@ -583,6 +585,8 @@ def forward(
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            image_bound=image_bound,
+            tgt_sizes=tgt_sizes,
             **kwargs,
         )
         return self.language_model.forward(
@@ -625,6 +629,7 @@ def get_multimodal_embeddings(
         )
         return inputs_embeds, attention_mask, position_ids
 
+    # Adopted from https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/llava/modeling_llava.py#L521
     def prepare_inputs_for_generation(
         self,
         input_ids,
@@ -649,7 +654,7 @@ def prepare_inputs_for_generation(
             elif past_length < input_ids.shape[1]:
                 input_ids = input_ids[:, past_length:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-            elif self.config.image_token_index in input_ids:
+            elif getattr(self.config, "image_token_index", -1) in input_ids:
                 input_ids = input_ids[:, input_ids.shape[1] - 1 :]
 
         position_ids = kwargs.get("position_ids", None)
@@ -673,6 +678,8 @@ def prepare_inputs_for_generation(
                 "attention_mask": attention_mask,
                 "pixel_values": pixel_values,
                 "image_sizes": image_sizes,
+                "image_bound": kwargs.get("image_bound"),
+                "tgt_sizes": kwargs.get("tgt_sizes"),
             }
         )
         return model_inputs
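The switch from self.config.image_token_index to getattr(self.config, "image_token_index", -1) in the shared prepare_inputs_for_generation presumably matters once model types whose configs may not define image_token_index (MiniCPM-V style models, for instance) reuse this method. A small, hypothetical illustration of the guard follows; the SimpleNamespace objects are stand-ins for real model configs.

# Hypothetical stand-in configs showing why the guarded lookup is needed:
# plain `config.image_token_index` raises AttributeError when the attribute
# is missing, while getattr(..., -1) yields a sentinel never found in input_ids.
from types import SimpleNamespace

config_with_token = SimpleNamespace(image_token_index=32000)
config_without_token = SimpleNamespace()  # e.g. a config with no image token

input_ids = [1, 2, 32000, 4]
print(getattr(config_with_token, "image_token_index", -1) in input_ids)     # True
print(getattr(config_without_token, "image_token_index", -1) in input_ids)  # False, no AttributeError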
@@ -1362,83 +1369,6 @@ def merge_vision_text_embeddings(
         )
         return vllm_embedding, attention_mask, position_ids
 
-    def prepare_inputs_for_generation(
-        self,
-        input_ids,
-        past_key_values=None,
-        inputs_embeds=None,
-        pixel_values=None,
-        image_sizes=None,
-        attention_mask=None,
-        **kwargs,
-    ):
-        if past_key_values is not None:
-            past_length = self.language_model._get_past_length(past_key_values)
-
-            # Keep only the unprocessed tokens:
-            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-            # input)
-            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-            # input_ids based on the past_length.
-            elif past_length < input_ids.shape[1]:
-                input_ids = input_ids[:, past_length:]
-
-        position_ids = kwargs.get("position_ids", None)
-        if attention_mask is not None and position_ids is None:
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids}
-
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask,
-                "pixel_values": pixel_values,
-                "image_sizes": image_sizes,
-                "image_bound": kwargs.get("image_bound"),
-                "tgt_sizes": kwargs.get("tgt_sizes"),
-            }
-        )
-        return model_inputs
-
-    def forward(
-        self,
-        input_ids,
-        pixel_values,
-        past_key_values=None,
-        inputs_embeds=None,
-        image_sizes=None,
-        attention_mask=None,
-        position_ids=None,
-        image_bound=None,
-        tgt_sizes=None,
-        **kwargs,
-    ):
-        return super().forward(
-            input_ids,
-            pixel_values,
-            past_key_values,
-            inputs_embeds,
-            image_sizes,
-            attention_mask,
-            position_ids,
-            image_bound=image_bound,
-            tgt_sizes=tgt_sizes,
-            **kwargs,
-        )
-
 
 MODEL_TYPE_TO_CLS_MAPPING = {
     "llava": _OVLlavaForCausalLM,
