1
+ from typing import Optional , Tuple
2
+
3
+ import torch
4
+ from mmcv .cnn import build_norm_layer
5
+ from mmcv .ops import DynamicScatter
6
+ from torch import Tensor , nn
7
+
8
+ from mmdet3d .registry import MODELS
9
+ from mmdet3d .models .voxel_encoders .utils import PFNLayer , get_paddings_indicator
10
+ @MODELS .register_module ()
11
+ class PillarFeatureNetAutoware (nn .Module ):
12
+ """Pillar Feature Net.
13
+
14
+ The network prepares the pillar features and performs forward pass
15
+ through PFNLayers.
16
+
17
+ Args:
18
+ in_channels (int, optional): Number of input features,
19
+ either x, y, z or x, y, z, r. Defaults to 4.
20
+ feat_channels (tuple, optional): Number of features in each of the
21
+ N PFNLayers. Defaults to (64, ).
22
+ with_distance (bool, optional): Whether to include Euclidean distance
23
+ to points. Defaults to False.
24
+ with_cluster_center (bool, optional): [description]. Defaults to True.
25
+ with_voxel_center (bool, optional): [description]. Defaults to True.
26
+ voxel_size (tuple[float], optional): Size of voxels, only utilize x
27
+ and y size. Defaults to (0.2, 0.2, 4).
28
+ point_cloud_range (tuple[float], optional): Point cloud range, only
29
+ utilizes x and y min. Defaults to (0, -40, -3, 70.4, 40, 1).
30
+ norm_cfg ([type], optional): [description].
31
+ Defaults to dict(type='BN1d', eps=1e-3, momentum=0.01).
32
+ mode (str, optional): The mode to gather point features. Options are
33
+ 'max' or 'avg'. Defaults to 'max'.
34
+ legacy (bool, optional): Whether to use the new behavior or
35
+ the original behavior. Defaults to True.
36
+ """
37
+
38
+ def __init__ (self ,
39
+ in_channels : Optional [int ] = 4 ,
40
+ feat_channels : Optional [tuple ] = (64 ,),
41
+ with_distance : Optional [bool ] = False ,
42
+ with_cluster_center : Optional [bool ] = True ,
43
+ with_voxel_center : Optional [bool ] = True ,
44
+ voxel_size : Optional [Tuple [float ]] = (0.2 , 0.2 , 4 ),
45
+ point_cloud_range : Optional [Tuple [float ]] = (0 , - 40 , - 3 , 70.4 ,
46
+ 40 , 1 ),
47
+ norm_cfg : Optional [dict ] = dict (
48
+ type = 'BN1d' , eps = 1e-3 , momentum = 0.01 ),
49
+ mode : Optional [str ] = 'max' ,
50
+ legacy : Optional [bool ] = True ,
51
+ use_voxel_center_z : Optional [bool ] = True , ):
52
+ super (PillarFeatureNetAutoware , self ).__init__ ()
53
+ assert len (feat_channels ) > 0
54
+ self .legacy = legacy
55
+ self .use_voxel_center_z = use_voxel_center_z
56
+ if with_cluster_center :
57
+ in_channels += 3
58
+ if with_voxel_center :
59
+ in_channels += 2
60
+ if self .use_voxel_center_z :
61
+ in_channels += 1
62
+ if with_distance :
63
+ in_channels += 1
64
+ self ._with_distance = with_distance
65
+ self ._with_cluster_center = with_cluster_center
66
+ self ._with_voxel_center = with_voxel_center
67
+ # Create PillarFeatureNet layers
68
+ self .in_channels = in_channels
69
+ feat_channels = [in_channels ] + list (feat_channels )
70
+ pfn_layers = []
71
+ for i in range (len (feat_channels ) - 1 ):
72
+ in_filters = feat_channels [i ]
73
+ out_filters = feat_channels [i + 1 ]
74
+ if i < len (feat_channels ) - 2 :
75
+ last_layer = False
76
+ else :
77
+ last_layer = True
78
+ pfn_layers .append (
79
+ PFNLayer (
80
+ in_filters ,
81
+ out_filters ,
82
+ norm_cfg = norm_cfg ,
83
+ last_layer = last_layer ,
84
+ mode = mode ))
85
+ self .pfn_layers = nn .ModuleList (pfn_layers )
86
+
87
+ # Need pillar (voxel) size and x/y offset in order to calculate offset
88
+ self .vx = voxel_size [0 ]
89
+ self .vy = voxel_size [1 ]
90
+ self .vz = voxel_size [2 ]
91
+ self .x_offset = self .vx / 2 + point_cloud_range [0 ]
92
+ self .y_offset = self .vy / 2 + point_cloud_range [1 ]
93
+ self .z_offset = self .vz / 2 + point_cloud_range [2 ]
94
+ self .point_cloud_range = point_cloud_range
95
+
96
+ def forward (self , features : Tensor , num_points : Tensor , coors : Tensor ,
97
+ * args , ** kwargs ) -> Tensor :
98
+ """Forward function.
99
+
100
+ Args:
101
+ features (torch.Tensor): Point features or raw points in shape
102
+ (N, M, C).
103
+ num_points (torch.Tensor): Number of points in each pillar.
104
+ coors (torch.Tensor): Coordinates of each voxel.
105
+
106
+ Returns:
107
+ torch.Tensor: Features of pillars.
108
+ """
109
+ features_ls = [features ]
110
+ # Find distance of x, y, and z from cluster center
111
+ if self ._with_cluster_center :
112
+ points_mean = features [:, :, :3 ].sum (
113
+ dim = 1 , keepdim = True ) / num_points .type_as (features ).view (
114
+ - 1 , 1 , 1 )
115
+ f_cluster = features [:, :, :3 ] - points_mean
116
+ features_ls .append (f_cluster )
117
+
118
+ # Find distance of x, y, and z from pillar center
119
+ dtype = features .dtype
120
+ if self ._with_voxel_center :
121
+ center_feature_size = 3 if self .use_voxel_center_z else 2
122
+ if not self .legacy :
123
+ f_center = torch .zeros_like (features [:, :, :center_feature_size ])
124
+ f_center [:, :, 0 ] = features [:, :, 0 ] - (
125
+ coors [:, 3 ].to (dtype ).unsqueeze (1 ) * self .vx +
126
+ self .x_offset )
127
+ f_center [:, :, 1 ] = features [:, :, 1 ] - (
128
+ coors [:, 2 ].to (dtype ).unsqueeze (1 ) * self .vy +
129
+ self .y_offset )
130
+ if self .use_voxel_center_z :
131
+ f_center [:, :, 2 ] = features [:, :, 2 ] - (
132
+ coors [:, 1 ].to (dtype ).unsqueeze (1 ) * self .vz +
133
+ self .z_offset )
134
+ else :
135
+ f_center = features [:, :, :center_feature_size ]
136
+ f_center [:, :, 0 ] = f_center [:, :, 0 ] - (
137
+ coors [:, 3 ].type_as (features ).unsqueeze (1 ) * self .vx +
138
+ self .x_offset )
139
+ f_center [:, :, 1 ] = f_center [:, :, 1 ] - (
140
+ coors [:, 2 ].type_as (features ).unsqueeze (1 ) * self .vy +
141
+ self .y_offset )
142
+ if self .use_voxel_center_z :
143
+ f_center [:, :, 2 ] = f_center [:, :, 2 ] - (
144
+ coors [:, 1 ].type_as (features ).unsqueeze (1 ) * self .vz +
145
+ self .z_offset )
146
+ features_ls .append (f_center )
147
+
148
+ if self ._with_distance :
149
+ points_dist = torch .norm (features [:, :, :3 ], 2 , 2 , keepdim = True )
150
+ features_ls .append (points_dist )
151
+
152
+ # Combine together feature decorations
153
+ features = torch .cat (features_ls , dim = - 1 )
154
+ # The feature decorations were calculated without regard to whether
155
+ # pillar was empty. Need to ensure that
156
+ # empty pillars remain set to zeros.
157
+ voxel_count = features .shape [1 ]
158
+ mask = get_paddings_indicator (num_points , voxel_count , axis = 0 )
159
+ mask = torch .unsqueeze (mask , - 1 ).type_as (features )
160
+ features *= mask
161
+
162
+ for pfn in self .pfn_layers :
163
+ features = pfn (features , num_points )
164
+
165
+ return features .squeeze (1 )
0 commit comments