Skip to content

Commit bbec36a

Browse files
committed
align nanollava input with original model
1 parent a59bb41 commit bbec36a

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

optimum/intel/openvino/modeling_visual_language.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -695,8 +695,10 @@ def forward(
695695
image_grid_thw=None,
696696
video_grid_thw=None,
697697
rope_deltas=None,
698+
images=None,
698699
**kwargs,
699700
):
701+
pixel_values = pixel_values if pixel_values is not None else images
700702
inputs_embeds, attention_mask, position_ids = self.get_multimodal_embeddings(
701703
input_ids,
702704
pixel_values,
@@ -794,6 +796,9 @@ def prepare_inputs_for_generation(
794796
else:
795797
model_inputs = {"input_ids": input_ids}
796798

799+
if pixel_values is None:
800+
pixel_values = kwargs.get("images")
801+
797802
model_inputs.update(
798803
{
799804
"position_ids": position_ids,
@@ -1907,7 +1912,7 @@ def preprocess_inputs(
19071912
attention_mask = torch.ones_like(input_ids, dtype=torch.int64)
19081913
result = {"input_ids": input_ids, "attention_mask": attention_mask}
19091914
if image is not None:
1910-
result["pixel_values"] = processor(images=[image], return_tensors="pt")["pixel_values"]
1915+
result["images"] = processor(images=[image], return_tensors="pt")["pixel_values"]
19111916
return result
19121917

19131918

tests/openvino/test_modeling.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -2182,11 +2182,7 @@ def test_compare_to_transformers(self, model_arch):
21822182
ov_model.clear_requests()
21832183
self._check_device_and_request(ov_model, test_device, False)
21842184

2185-
# nanollava pixel_values input named as images
2186-
if model_arch == "nanollava":
2187-
pixel_values = transformers_inputs.pop("pixel_values", None)
2188-
transformers_inputs["images"] = pixel_values
2189-
# pytorch minicpmv is not designed to be used via forward
2185+
# pytorch minicpmv and internvl2 is not designed to be used via forward
21902186
if model_arch not in ["minicpmv", "internvl2"]:
21912187
set_seed(SEED)
21922188
ov_outputs = ov_model(**inputs)

0 commit comments

Comments
 (0)