@@ -40,7 +40,7 @@ class PillarFeatureNet(nn.Module):
40
40
41
41
def __init__ (self ,
42
42
in_channels : Optional [int ] = 4 ,
43
- feat_channels : Optional [tuple ] = (64 , ),
43
+ feat_channels : Optional [tuple ] = (64 ,),
44
44
with_distance : Optional [bool ] = False ,
45
45
with_cluster_center : Optional [bool ] = True ,
46
46
with_voxel_center : Optional [bool ] = True ,
@@ -50,14 +50,18 @@ def __init__(self,
50
50
norm_cfg : Optional [dict ] = dict (
51
51
type = 'BN1d' , eps = 1e-3 , momentum = 0.01 ),
52
52
mode : Optional [str ] = 'max' ,
53
- legacy : Optional [bool ] = True ):
53
+ legacy : Optional [bool ] = True ,
54
+ use_voxel_center_z : Optional [bool ] = True , ):
54
55
super (PillarFeatureNet , self ).__init__ ()
55
56
assert len (feat_channels ) > 0
56
57
self .legacy = legacy
58
+ self .use_voxel_center_z = use_voxel_center_z
57
59
if with_cluster_center :
58
60
in_channels += 3
59
61
if with_voxel_center :
60
- in_channels += 3
62
+ in_channels += 2
63
+ if self .use_voxel_center_z :
64
+ in_channels += 1
61
65
if with_distance :
62
66
in_channels += 1
63
67
self ._with_distance = with_distance
@@ -110,35 +114,38 @@ def forward(self, features: Tensor, num_points: Tensor, coors: Tensor,
110
114
if self ._with_cluster_center :
111
115
points_mean = features [:, :, :3 ].sum (
112
116
dim = 1 , keepdim = True ) / num_points .type_as (features ).view (
113
- - 1 , 1 , 1 )
117
+ - 1 , 1 , 1 )
114
118
f_cluster = features [:, :, :3 ] - points_mean
115
119
features_ls .append (f_cluster )
116
120
117
121
# Find distance of x, y, and z from pillar center
118
122
dtype = features .dtype
119
123
if self ._with_voxel_center :
124
+ center_feature_size = 3 if self .use_voxel_center_z else 2
120
125
if not self .legacy :
121
- f_center = torch .zeros_like (features [:, :, :3 ])
126
+ f_center = torch .zeros_like (features [:, :, :center_feature_size ])
122
127
f_center [:, :, 0 ] = features [:, :, 0 ] - (
123
- coors [:, 3 ].to (dtype ).unsqueeze (1 ) * self .vx +
124
- self .x_offset )
128
+ coors [:, 3 ].to (dtype ).unsqueeze (1 ) * self .vx +
129
+ self .x_offset )
125
130
f_center [:, :, 1 ] = features [:, :, 1 ] - (
126
- coors [:, 2 ].to (dtype ).unsqueeze (1 ) * self .vy +
127
- self .y_offset )
128
- f_center [:, :, 2 ] = features [:, :, 2 ] - (
129
- coors [:, 1 ].to (dtype ).unsqueeze (1 ) * self .vz +
130
- self .z_offset )
131
+ coors [:, 2 ].to (dtype ).unsqueeze (1 ) * self .vy +
132
+ self .y_offset )
133
+ if self .use_voxel_center_z :
134
+ f_center [:, :, 2 ] = features [:, :, 2 ] - (
135
+ coors [:, 1 ].to (dtype ).unsqueeze (1 ) * self .vz +
136
+ self .z_offset )
131
137
else :
132
- f_center = features [:, :, :3 ]
138
+ f_center = features [:, :, :center_feature_size ]
133
139
f_center [:, :, 0 ] = f_center [:, :, 0 ] - (
134
- coors [:, 3 ].type_as (features ).unsqueeze (1 ) * self .vx +
135
- self .x_offset )
140
+ coors [:, 3 ].type_as (features ).unsqueeze (1 ) * self .vx +
141
+ self .x_offset )
136
142
f_center [:, :, 1 ] = f_center [:, :, 1 ] - (
137
- coors [:, 2 ].type_as (features ).unsqueeze (1 ) * self .vy +
138
- self .y_offset )
139
- f_center [:, :, 2 ] = f_center [:, :, 2 ] - (
140
- coors [:, 1 ].type_as (features ).unsqueeze (1 ) * self .vz +
141
- self .z_offset )
143
+ coors [:, 2 ].type_as (features ).unsqueeze (1 ) * self .vy +
144
+ self .y_offset )
145
+ if self .use_voxel_center_z :
146
+ f_center [:, :, 2 ] = f_center [:, :, 2 ] - (
147
+ coors [:, 1 ].type_as (features ).unsqueeze (1 ) * self .vz +
148
+ self .z_offset )
142
149
features_ls .append (f_center )
143
150
144
151
if self ._with_distance :
@@ -193,7 +200,7 @@ class DynamicPillarFeatureNet(PillarFeatureNet):
193
200
194
201
def __init__ (self ,
195
202
in_channels : Optional [int ] = 4 ,
196
- feat_channels : Optional [tuple ] = (64 , ),
203
+ feat_channels : Optional [tuple ] = (64 ,),
197
204
with_distance : Optional [bool ] = False ,
198
205
with_cluster_center : Optional [bool ] = True ,
199
206
with_voxel_center : Optional [bool ] = True ,
@@ -264,15 +271,15 @@ def map_voxel_center_to_point(self, pts_coors: Tensor, voxel_mean: Tensor,
264
271
canvas = voxel_mean .new_zeros (canvas_channel , canvas_len )
265
272
# Only include non-empty pillars
266
273
indices = (
267
- voxel_coors [:, 0 ] * canvas_y * canvas_x +
268
- voxel_coors [:, 2 ] * canvas_x + voxel_coors [:, 3 ])
274
+ voxel_coors [:, 0 ] * canvas_y * canvas_x +
275
+ voxel_coors [:, 2 ] * canvas_x + voxel_coors [:, 3 ])
269
276
# Scatter the blob back to the canvas
270
277
canvas [:, indices .long ()] = voxel_mean .t ()
271
278
272
279
# Step 2: get voxel mean for each point
273
280
voxel_index = (
274
- pts_coors [:, 0 ] * canvas_y * canvas_x +
275
- pts_coors [:, 2 ] * canvas_x + pts_coors [:, 3 ])
281
+ pts_coors [:, 0 ] * canvas_y * canvas_x +
282
+ pts_coors [:, 2 ] * canvas_x + pts_coors [:, 3 ])
276
283
center_per_point = canvas [:, voxel_index .long ()].t ()
277
284
return center_per_point
278
285
@@ -301,11 +308,11 @@ def forward(self, features: Tensor, coors: Tensor) -> Tensor:
301
308
if self ._with_voxel_center :
302
309
f_center = features .new_zeros (size = (features .size (0 ), 3 ))
303
310
f_center [:, 0 ] = features [:, 0 ] - (
304
- coors [:, 3 ].type_as (features ) * self .vx + self .x_offset )
311
+ coors [:, 3 ].type_as (features ) * self .vx + self .x_offset )
305
312
f_center [:, 1 ] = features [:, 1 ] - (
306
- coors [:, 2 ].type_as (features ) * self .vy + self .y_offset )
313
+ coors [:, 2 ].type_as (features ) * self .vy + self .y_offset )
307
314
f_center [:, 2 ] = features [:, 2 ] - (
308
- coors [:, 1 ].type_as (features ) * self .vz + self .z_offset )
315
+ coors [:, 1 ].type_as (features ) * self .vz + self .z_offset )
309
316
features_ls .append (f_center )
310
317
311
318
if self ._with_distance :
@@ -324,3 +331,4 @@ def forward(self, features: Tensor, coors: Tensor) -> Tensor:
324
331
features = torch .cat ([point_feats , feat_per_point ], dim = 1 )
325
332
326
333
return voxel_feats , voxel_coors
334
+
0 commit comments