@@ -1360,6 +1360,90 @@ def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
         image_features = self.multi_modal_projector(image_features)
         return image_features
 
+    def pack_image_features(self, image_features, image_sizes, image_newline=None):
+        """
+        Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
+
+        Args:
+            image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
+                List of image feature tensors, each containing all the visual features of all patches.
+            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
+                Actual image size of each image (H, W).
+            image_newline (`torch.Tensor` of shape `(embed_dim)`)
+                New line embedding vector.
+        Returns:
+            image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
+            feature_lens (`List[int]`)
+                Token length of each image in image_features.
+        """
+        from transformers.models.llava_next_video.modeling_llava_next_video import (
+            get_anyres_image_grid_shape,
+            unpad_image,
+        )
+
+        new_image_features = []
+        feature_lens = []
+        vision_feature_select_strategy = self.config.vision_feature_select_strategy
+        for image_idx, image_feature in enumerate(image_features):
+            if image_feature.shape[0] > 1:
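+                # Slot 0 holds the base (whole-image) feature; the remaining
+                # slots are the high-resolution anyres grid patches.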
+                base_image_feature = image_feature[0]
+                image_feature = image_feature[1:]
+                height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
+
+                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+                    image_sizes[image_idx],
+                    self.config.image_grid_pinpoints,
+                    self.config.vision_config.image_size,
+                )
+
+                if (
+                    np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0
+                    and vision_feature_select_strategy == "default"
+                ):
+                    logger.warning_once(
+                        "Image feature shape does not line up with the provided patch size. "
+                        "You may be using the `default` vision_feature_select_strategy with a"
+                        " visual encoder that does not have CLS."
+                    )
+
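+                # Rearrange the flat per-patch features into one 2D spatial grid:
+                # (grid_h * grid_w, h * w, dim) -> (dim, grid_h * h, grid_w * w).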
+                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
+                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
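+                # Crop away the padding added during preprocessing so the grid
+                # matches the original aspect ratio of the image.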
+                image_feature = unpad_image(image_feature, image_sizes[image_idx])
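+                # Append the learned "newline" embedding after every grid row so
+                # the language model can tell consecutive rows apart.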
+                if image_newline is not None:
+                    image_feature = torch.cat(
+                        (
+                            image_feature,
+                            image_newline[:, None, None]
+                            .expand(*image_feature.shape[:-1], 1)
+                            .to(image_feature.device, image_feature.dtype),
+                        ),
+                        dim=-1,
+                    )
+                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+            else:
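+                # Single-patch image: keep the lone feature and terminate it
+                # with the newline embedding.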
+                image_feature = image_feature[0]
+                if image_newline is not None:
+                    image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
+            new_image_features.append(image_feature)
+            feature_lens.append(image_feature.size(0))
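+        # Pack every image into one flat (all_feat_len, embed_dim) tensor and
+        # record per-image token counts so callers can split it back apart.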
+        image_features = torch.cat(new_image_features, dim=0)
+        feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
+        return image_features, feature_lens
+
     @staticmethod
     def preprocess_inputs(
         text: str,