@@ -49,6 +49,8 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
49
49
voxelization and dynamic voxelization. Defaults to 'hard'.
50
50
voxel_layer (dict or :obj:`ConfigDict`, optional): Voxelization layer
51
51
config. Defaults to None.
52
+ max_voxels (int): Maximum number of voxels in each voxel grid. Defaults
53
+ to None.
52
54
mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
53
55
Defaults to None.
54
56
std (Sequence[Number], optional): The pixel standard deviation of
@@ -77,6 +79,7 @@ def __init__(self,
77
79
voxel : bool = False ,
78
80
voxel_type : str = 'hard' ,
79
81
voxel_layer : OptConfigType = None ,
82
+ max_voxels : Optional [int ] = None ,
80
83
mean : Sequence [Number ] = None ,
81
84
std : Sequence [Number ] = None ,
82
85
pad_size_divisor : int = 1 ,
@@ -103,6 +106,7 @@ def __init__(self,
103
106
batch_augments = batch_augments )
104
107
self .voxel = voxel
105
108
self .voxel_type = voxel_type
109
+ self .max_voxels = max_voxels
106
110
if voxel :
107
111
self .voxel_layer = VoxelizationByGridShape (** voxel_layer )
108
112
@@ -423,20 +427,22 @@ def voxelize(self, points: List[torch.Tensor],
423
427
res_coors -= res_coors .min (0 )[0 ]
424
428
425
429
res_coors_numpy = res_coors .cpu ().numpy ()
426
- inds , voxel2point_map = self .sparse_quantize (
430
+ inds , point2voxel_map = self .sparse_quantize (
427
431
res_coors_numpy , return_index = True , return_inverse = True )
428
- voxel2point_map = torch .from_numpy (voxel2point_map ).cuda ()
429
- if self .training :
430
- if len (inds ) > 80000 :
431
- inds = np .random .choice (inds , 80000 , replace = False )
432
+ point2voxel_map = torch .from_numpy (point2voxel_map ).cuda ()
433
+ if self .training and self .max_voxels is not None :
434
+ if len (inds ) > self .max_voxels :
435
+ inds = np .random .choice (
436
+ inds , self .max_voxels , replace = False )
432
437
inds = torch .from_numpy (inds ).cuda ()
433
- data_sample .gt_pts_seg .voxel_semantic_mask \
434
- = data_sample .gt_pts_seg .pts_semantic_mask [inds ]
438
+ if hasattr (data_sample .gt_pts_seg , 'pts_semantic_mask' ):
439
+ data_sample .gt_pts_seg .voxel_semantic_mask \
440
+ = data_sample .gt_pts_seg .pts_semantic_mask [inds ]
435
441
res_voxel_coors = res_coors [inds ]
436
442
res_voxels = res [inds ]
437
443
res_voxel_coors = F .pad (
438
444
res_voxel_coors , (0 , 1 ), mode = 'constant' , value = i )
439
- data_sample .voxel2point_map = voxel2point_map .long ()
445
+ data_sample .point2voxel_map = point2voxel_map .long ()
440
446
voxels .append (res_voxels )
441
447
coors .append (res_voxel_coors )
442
448
voxels = torch .cat (voxels , dim = 0 )
@@ -466,12 +472,12 @@ def get_voxel_seg(self, res_coors: torch.Tensor, data_sample: SampleList):
466
472
True )
467
473
voxel_semantic_mask = torch .argmax (voxel_semantic_mask , dim = - 1 )
468
474
data_sample .gt_pts_seg .voxel_semantic_mask = voxel_semantic_mask
469
- data_sample .gt_pts_seg . point2voxel_map = point2voxel_map
475
+ data_sample .point2voxel_map = point2voxel_map
470
476
else :
471
477
pseudo_tensor = res_coors .new_ones ([res_coors .shape [0 ], 1 ]).float ()
472
478
_ , _ , point2voxel_map = dynamic_scatter_3d (pseudo_tensor ,
473
479
res_coors , 'mean' , True )
474
- data_sample .gt_pts_seg . point2voxel_map = point2voxel_map
480
+ data_sample .point2voxel_map = point2voxel_map
475
481
476
482
def ravel_hash (self , x : np .ndarray ) -> np .ndarray :
477
483
"""Get voxel coordinates hash for np.unique().
0 commit comments