
Commit 73eeb9f

fix llava legacy processing selection

1 parent 682362d commit 73eeb9f

File tree

1 file changed: +20 -14 lines changed

optimum/intel/openvino/modeling_visual_language.py (+20 -14)
@@ -748,13 +748,12 @@ def merge_vision_text_embeddings(
         image_features = torch.from_numpy(vision_embeds) if isinstance(vision_embeds, np.ndarray) else vision_embeds
         inputs_embeds = torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds
         if legacy_processing is None:
-            legacy_processing = (
-                not hasattr(self.config, "image_seq_length")
-                or ((input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length)
-                or (input_ids.shape[-1] == 1)
+            legacy_processing = not (hasattr(self.config, "image_seq_length") and (input_ids.shape[-1] == 1)) or (
+                (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
             )

         if legacy_processing:
+            logger.warn("LEGACY")
             pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1

             num_images, num_image_patches, embed_dim = image_features.shape
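Note: in the rewritten predicate, Python's `not` binds tighter than `or`, so the new expression groups as `(not (A and B)) or C`. A minimal sketch of that grouping for reference; the names `has_seq_len`, `is_decode_step`, and `few_image_tokens` are illustrative stand-ins, not identifiers from the patch:

def select_legacy(has_seq_len, is_decode_step, few_image_tokens):
    # Groups as: (not (has_seq_len and is_decode_step)) or few_image_tokens
    return not (has_seq_len and is_decode_step) or few_image_tokens

# Prefill-style call (is_decode_step=False) on a config that defines
# image_seq_length: the left operand is already True, so the token-count
# clause is never consulted.
assert select_legacy(True, False, False) is True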
@@ -832,11 +831,15 @@ def merge_vision_text_embeddings(
     def get_multimodal_embeddings(
         self, input_ids, pixel_values=None, attention_mask=None, position_ids=None, past_key_values=None, **kwargs
     ):
-        legacy_processing = (
-            not hasattr(self.config, "image_seq_length")
-            or ((input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length)
-            or (input_ids.shape[-1] == 1 and pixel_values is not None)
-        )
+        legacy_processing = getattr(self, "_legacy_processing", not hasattr(self.config, "image_seq_length"))
+        inputs_embeds = self.get_text_embeddings(input_ids, **kwargs)
+
+        if pixel_values is not None and not legacy_processing and past_key_values is None:
+            legacy_processing = (input_ids == self.config.image_token_index).sum(
+                1
+            ).max() < self.config.image_seq_length
+            self._legacy_processing = legacy_processing
+
         inputs_embeds, attention_mask, position_ids = super().get_multimodal_embeddings(
             input_ids, pixel_values, attention_mask, position_ids, legacy_processing=legacy_processing, **kwargs
         )
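Note: the reworked selection computes the legacy decision once, on the first forward pass (when past_key_values is None), and caches it on the instance, so later decode steps reuse it instead of re-counting image tokens in a single-token input_ids. A minimal sketch of that cache-on-prefill pattern, assuming torch; the class name and the sample values (32000, 576) are illustrative, while `_legacy_processing`, `image_token_index`, and `image_seq_length` mirror the patch:

import torch

class LegacyFlagCache:
    """Illustrative stand-in for the cache-on-prefill pattern above."""

    def __init__(self, image_token_index, image_seq_length):
        self.image_token_index = image_token_index
        self.image_seq_length = image_seq_length

    def decide(self, input_ids, pixel_values, past_key_values):
        # Default mirrors `not hasattr(config, "image_seq_length")`; this
        # stand-in always has the attribute, so the uncached default is False.
        legacy = getattr(self, "_legacy_processing", False)
        if pixel_values is not None and not legacy and past_key_values is None:
            # Prefill only: the full prompt is visible, so comparing the image
            # placeholder count against image_seq_length is meaningful here.
            legacy = bool((input_ids == self.image_token_index).sum(1).max() < self.image_seq_length)
            self._legacy_processing = legacy
        return legacy

cache = LegacyFlagCache(image_token_index=32000, image_seq_length=576)
prefill_ids = torch.full((1, 600), 32000)  # prompt already expanded with image tokens
print(cache.decide(prefill_ids, pixel_values=object(), past_key_values=None))  # False
decode_ids = torch.tensor([[42]])  # one new token per decode step
print(cache.decide(decode_ids, pixel_values=None, past_key_values=("kv",)))  # reuses cached False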
@@ -847,6 +850,7 @@ def get_multimodal_embeddings(
         return inputs_embeds, attention_mask, position_ids

     def _filter_unattended_tokens(self, input_ids, attention_mask, past_key_values):
+        logger.warn("LEGACY")
         if not self.language_model.stateful:
             first_layer_past_key_value = torch.from_numpy(past_key_values[0][0][:, :, :, 0])
         else:
@@ -954,12 +958,14 @@ def get_multimodal_embeddings(
         from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches

         inputs_embeds = self.get_text_embeddings(input_ids, **kwargs)
+        legacy_processing = getattr(self, "_legacy_processing", not hasattr(self.config, "image_seq_length"))
+
+        if pixel_values is not None and not legacy_processing and past_key_values is None:
+            legacy_processing = (input_ids == self.config.image_token_index).sum(
+                1
+            ).max() < self.config.image_seq_length
+            self._legacy_processing = legacy_processing

-        legacy_processing = (
-            not hasattr(self.config, "image_seq_length")
-            or ((input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length)
-            or (input_ids.shape[-1] == 1 and pixel_values is not None)
-        )
         if pixel_values is not None and pixel_values.size(0) > 0:
             # ! infer image_num_patches from image_sizes
             image_num_patches = [
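Note on the design choice (a reviewer-style gloss, not part of the commit): during generation each decode step passes only the newly sampled token, so re-counting image placeholder tokens there would always fall below image_seq_length and could flip the model onto the legacy path mid-generation; caching the prefill-time decision in `_legacy_processing` avoids that. A toy illustration, assuming torch and illustrative values:

import torch

image_token_index, image_seq_length = 32000, 576  # illustrative values

decode_ids = torch.tensor([[1234]])  # a decode step carries a single new token
count = (decode_ids == image_token_index).sum(1).max()
print(count.item())                    # 0
print(bool(count < image_seq_length))  # True: re-checking here would wrongly signal legacy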
