@@ -1204,9 +1204,10 @@ def merge_vision_text_embeddings(
         attention_mask,
         position_ids=None,
         legacy_processing=False,
+        image_token_index=None,
         **kwargs,
     ):
-        image_token_index = self.config.image_token_index
+        image_token_index = self.config.image_token_index if image_token_index is None else image_token_index
         image_features = torch.from_numpy(vision_embeds) if isinstance(vision_embeds, np.ndarray) else vision_embeds
         inputs_embeds = torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds
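Note: the hunk above makes the placeholder token index an optional argument that falls back to the previous config-based behaviour, so the same merge routine can also be driven with `config.video_token_index`. A minimal sketch of the fallback pattern, written as a hypothetical standalone helper (`resolve_token_index` does not exist in the codebase; the real method does this inline):

    def resolve_token_index(config, token_index=None):
        # Default to the image placeholder id from the model config; callers that
        # merge video embeddings can pass config.video_token_index instead.
        return config.image_token_index if token_index is None else token_index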
@@ -1235,7 +1236,7 @@ def merge_vision_text_embeddings(
             # Whether to turn off right padding
             # 1. Create a mask to know where special image tokens are
-            special_image_token_mask = input_ids == image_token_index
+            special_image_token_mask = torch.tensor(input_ids == image_token_index)
             # special_image_token_mask: [bsz, seqlen]
             num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
             # num_special_image_tokens: [bsz]
@@ -1328,7 +1329,7 @@ def merge_vision_text_embeddings(
             final_attention_mask |= image_to_overwrite
             position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
         else:
-            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+            special_image_mask = torch.tensor((input_ids == image_token_index)).unsqueeze(-1).expand_as(inputs_embeds)
             image_features = image_features.to(inputs_embeds.dtype)
             final_embedding = inputs_embeds.masked_scatter(special_image_mask, image_features)
             final_attention_mask = attention_mask
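Note: both hunks above wrap the token comparison in `torch.tensor(...)`, and the second one also switches from `self.config.image_token_index` to the resolved `image_token_index`. The wrapping is presumably needed because `input_ids` can arrive as a NumPy array on this inference path, in which case the comparison yields a NumPy boolean array that the downstream tensor ops (`torch.sum`, `unsqueeze`, `masked_scatter`) would reject. A minimal sketch, assuming NumPy `input_ids` and an illustrative placeholder id of 32000:

    import numpy as np
    import torch

    input_ids = np.array([[1, 32000, 32000, 5]])
    image_token_index = 32000  # illustrative id; real models take it from the config

    # NumPy comparison -> NumPy bool array; torch.tensor() turns it into a mask tensor
    special_image_token_mask = torch.tensor(input_ids == image_token_index)
    num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)  # tensor([2])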
@@ -1432,28 +1433,43 @@ def add_video_features(
         legacy_processing,
         **kwargs,
     ):
+        # Adopted from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/llava_next_video/modeling_llava_next_video.py#L732-L751
         video_features = self.get_video_features(pixel_values_videos, input_ids)
         if video_features is not None:
             if legacy_processing:
-                raise ValueError("Video processing supported only for transformers>=4.45 preprocessing.")
-            inputs_embeds = torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds
-            video_features = [feature.flatten(0, 1) for feature in video_features]
-            video_feature_lens = [feature.size(0) for feature in video_features]
-            video_features = torch.cat(video_features, dim=0)
-            video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
-
-            special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
-            special_image_mask = special_image_mask.expand_as(inputs_embeds)
-            if inputs_embeds[special_image_mask].numel() != video_features.numel():
-                n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
-                n_video_features = video_features.shape[0]
-                raise ValueError(
-                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                video_feature_lens = [feature.size(0) for feature in video_features]
+                inputs_embeds, attention_mask, position_ids = self.merge_vision_text_embeddings(
+                    video_features,
+                    inputs_embeds,
+                    video_feature_lens,
+                    input_ids,
+                    attention_mask,
+                    position_ids,
+                    legacy_processing,
+                    self.config.video_token_index,
                 )
-            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
+            else:
+                inputs_embeds = (
+                    torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds
+                )
+                video_features = [feature.flatten(0, 1) for feature in video_features]
+                video_feature_lens = [feature.size(0) for feature in video_features]
+                video_features = torch.cat(video_features, dim=0)
+                video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
+
+                special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
+                special_image_mask = special_image_mask.expand_as(inputs_embeds)
+                if inputs_embeds[special_image_mask].numel() != video_features.numel():
+                    n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
+                    n_video_features = video_features.shape[0]
+                    raise ValueError(
+                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                    )
+                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
         return inputs_embeds, attention_mask, position_ids
 
     def get_video_features(self, pixel_values, input_ids=None, **kwargs):
+        # Adopted from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/llava_next_video/modeling_llava_next_video.py#L835
         if input_ids is not None and input_ids.shape[1] == 1:
             return None
         batch_size, frames, channels, height, width = pixel_values.shape
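Note: after this change, the legacy branch of `add_video_features` delegates to `merge_vision_text_embeddings` with `self.config.video_token_index` as the new trailing argument instead of raising, while the transformers>=4.45 path keeps the masked-scatter merge. A toy illustration of that masked-scatter mechanism (shapes and the token id 7 are invented for brevity):

    import torch

    hidden = 4
    inputs_embeds = torch.zeros(1, 6, hidden)        # [bsz, seqlen, hidden]
    input_ids = torch.tensor([[1, 7, 7, 7, 2, 3]])   # 7 = hypothetical video placeholder id
    video_features = torch.ones(3, hidden)           # one embedding row per placeholder

    mask = (input_ids == 7).unsqueeze(-1).expand_as(inputs_embeds)
    merged = inputs_embeds.masked_scatter(mask, video_features)
    # positions 1..3 of `merged` now hold the video features; all other rows are unchanged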