@@ -748,13 +748,12 @@ def merge_vision_text_embeddings(
         image_features = torch.from_numpy(vision_embeds) if isinstance(vision_embeds, np.ndarray) else vision_embeds
         inputs_embeds = torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds
         if legacy_processing is None:
-            legacy_processing = (
-                not hasattr(self.config, "image_seq_length")
-                or ((input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length)
-                or (input_ids.shape[-1] == 1)
+            legacy_processing = not (hasattr(self.config, "image_seq_length") and (input_ids.shape[-1] == 1)) or (
+                (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
             )

         if legacy_processing:
+            logger.warn("LEGACY")
             pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1

             num_images, num_image_patches, embed_dim = image_features.shape
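For clarity on the rewritten guard: Python parses `not A or B` as `(not A) or B`. A minimal, self-contained sketch (toy config values invented for illustration, not taken from this diff) showing how the new expression evaluates on a multi-token prefill:

import torch

# Toy stand-ins for self.config and input_ids; values are illustrative only.
class ToyConfig:
    image_seq_length = 576
    image_token_index = 32000

config = ToyConfig()
input_ids = torch.full((1, 600), config.image_token_index)  # prefill, not a 1-token decode step

# `not A or B` evaluates as `(not A) or B`: on any multi-token prompt the
# left operand is already True, so the token-count check is short-circuited.
legacy_processing = not (hasattr(config, "image_seq_length") and (input_ids.shape[-1] == 1)) or (
    (input_ids == config.image_token_index).sum(1).max() < config.image_seq_length
)
print(legacy_processing)  # True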
@@ -832,11 +831,15 @@ def merge_vision_text_embeddings(
     def get_multimodal_embeddings(
         self, input_ids, pixel_values=None, attention_mask=None, position_ids=None, past_key_values=None, **kwargs
     ):
-        legacy_processing = (
-            not hasattr(self.config, "image_seq_length")
-            or ((input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length)
-            or (input_ids.shape[-1] == 1 and pixel_values is not None)
-        )
+        legacy_processing = getattr(self, "_legacy_processing", not hasattr(self.config, "image_seq_length"))
+        inputs_embeds = self.get_text_embeddings(input_ids, **kwargs)
+
+        if pixel_values is not None and not legacy_processing and past_key_values is None:
+            legacy_processing = (input_ids == self.config.image_token_index).sum(
+                1
+            ).max() < self.config.image_seq_length
+            self._legacy_processing = legacy_processing
+

         inputs_embeds, attention_mask, position_ids = super().get_multimodal_embeddings(
             input_ids, pixel_values, attention_mask, position_ids, legacy_processing=legacy_processing, **kwargs
         )
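The replacement above computes the legacy/packed decision once, on the prefill call (where `past_key_values is None` and the full prompt is still visible), then memoizes it on the instance so single-token decode steps reuse it. A minimal sketch of that caching pattern, with a hypothetical `ToyModel` standing in for the real class:

import torch

class ToyModel:
    # Hypothetical stand-in; only the caching logic mirrors the diff.
    def __init__(self, config):
        self.config = config

    def decide_legacy_processing(self, input_ids, pixel_values, past_key_values):
        # Fall back to the cached decision; before any prefill, treat configs
        # that predate `image_seq_length` as legacy.
        legacy = getattr(self, "_legacy_processing", not hasattr(self.config, "image_seq_length"))
        # Re-derive only on prefill, while the image placeholder tokens are present.
        if pixel_values is not None and not legacy and past_key_values is None:
            legacy = (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
            self._legacy_processing = legacy
        return legacy

Note that the cached flag lives on the instance, which is what lets later decode steps skip the token-count check entirely.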
@@ -847,6 +850,7 @@ def get_multimodal_embeddings(
         return inputs_embeds, attention_mask, position_ids

     def _filter_unattended_tokens(self, input_ids, attention_mask, past_key_values):
+        logger.warn("LEGACY")
         if not self.language_model.stateful:
             first_layer_past_key_value = torch.from_numpy(past_key_values[0][0][:, :, :, 0])
         else:
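A small aside on the added call: assuming `logger` is a standard-library `logging.Logger` (typical in this codebase), `warn()` is a soft-deprecated alias of `warning()`:

import logging

logger = logging.getLogger(__name__)
logger.warning("LEGACY")  # preferred spelling; logger.warn() is a deprecated alias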
@@ -954,12 +958,14 @@ def get_multimodal_embeddings(
         from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches

         inputs_embeds = self.get_text_embeddings(input_ids, **kwargs)
+        legacy_processing = getattr(self, "_legacy_processing", not hasattr(self.config, "image_seq_length"))
+
+        if pixel_values is not None and not legacy_processing and past_key_values is None:
+            legacy_processing = (input_ids == self.config.image_token_index).sum(
+                1
+            ).max() < self.config.image_seq_length
+            self._legacy_processing = legacy_processing

-        legacy_processing = (
-            not hasattr(self.config, "image_seq_length")
-            or ((input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length)
-            or (input_ids.shape[-1] == 1 and pixel_values is not None)
-        )
         if pixel_values is not None and pixel_values.size(0) > 0:
             # ! infer image_num_patches from image_sizes
             image_num_patches = [
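The tail of this hunk (truncated above) builds `image_num_patches` from `image_sizes` via the imported helper. A hedged usage sketch, with illustrative values rather than ones read from this model's config:

from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches

# Illustrative values; the real code reads config.image_grid_pinpoints and
# config.vision_config.image_size.
grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]

num_patches = image_size_to_num_patches(
    image_size=(900, 1200),        # (height, width) of one input image
    grid_pinpoints=grid_pinpoints,
    patch_size=336,
)
print(num_patches)  # patches for the best-fitting grid, plus the base patch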