|
| 1 | +from typing import Optional, Tuple |
| 2 | + |
| 3 | +import torch |
| 4 | +from torch import Tensor, nn |
| 5 | + |
| 6 | +from mmdet3d.models.voxel_encoders.utils import (PFNLayer, |
| 7 | + get_paddings_indicator) |
| 8 | +from mmdet3d.registry import MODELS |
| 9 | + |
| 10 | + |
| 11 | +@MODELS.register_module() |
| 12 | +class PillarFeatureNetAutoware(nn.Module): |
| 13 | + """Pillar Feature Net. |
| 14 | +
|
| 15 | + The network prepares the pillar features and performs forward pass |
| 16 | + through PFNLayers. |
| 17 | +
|
| 18 | + Args: |
| 19 | + in_channels (int, optional): Number of input features, |
| 20 | + either x, y, z or x, y, z, r. Defaults to 4. |
| 21 | + feat_channels (tuple, optional): Number of features in each of the |
| 22 | + N PFNLayers. Defaults to (64, ). |
| 23 | + with_distance (bool, optional): Whether to include Euclidean distance |
| 24 | + to points. Defaults to False. |
| 25 | + with_cluster_center (bool, optional): [description]. Defaults to True. |
| 26 | + with_voxel_center (bool, optional): [description]. Defaults to True. |
| 27 | + voxel_size (tuple[float], optional): Size of voxels, only utilize x |
| 28 | + and y size. Defaults to (0.2, 0.2, 4). |
| 29 | + point_cloud_range (tuple[float], optional): Point cloud range, only |
| 30 | + utilizes x and y min. Defaults to (0, -40, -3, 70.4, 40, 1). |
| 31 | + norm_cfg ([type], optional): [description]. |
| 32 | + Defaults to dict(type='BN1d', eps=1e-3, momentum=0.01). |
| 33 | + mode (str, optional): The mode to gather point features. Options are |
| 34 | + 'max' or 'avg'. Defaults to 'max'. |
| 35 | + legacy (bool, optional): Whether to use the new behavior or |
| 36 | + the original behavior. Defaults to True. |
| 37 | + """ |
| 38 | + |
| 39 | + def __init__( |
| 40 | + self, |
| 41 | + in_channels: Optional[int] = 4, |
| 42 | + feat_channels: Optional[tuple] = (64, ), |
| 43 | + with_distance: Optional[bool] = False, |
| 44 | + with_cluster_center: Optional[bool] = True, |
| 45 | + with_voxel_center: Optional[bool] = True, |
| 46 | + voxel_size: Optional[Tuple[float]] = (0.2, 0.2, 4), |
| 47 | + point_cloud_range: Optional[Tuple[float]] = (0, -40, -3, 70.4, 40, 1), |
| 48 | + norm_cfg: Optional[dict] = dict(type='BN1d', eps=1e-3, momentum=0.01), |
| 49 | + mode: Optional[str] = 'max', |
| 50 | + legacy: Optional[bool] = True, |
| 51 | + use_voxel_center_z: Optional[bool] = True, |
| 52 | + ): |
| 53 | + super(PillarFeatureNetAutoware, self).__init__() |
| 54 | + assert len(feat_channels) > 0 |
| 55 | + self.legacy = legacy |
| 56 | + self.use_voxel_center_z = use_voxel_center_z |
| 57 | + if with_cluster_center: |
| 58 | + in_channels += 3 |
| 59 | + if with_voxel_center: |
| 60 | + in_channels += 2 |
| 61 | + if self.use_voxel_center_z: |
| 62 | + in_channels += 1 |
| 63 | + if with_distance: |
| 64 | + in_channels += 1 |
| 65 | + self._with_distance = with_distance |
| 66 | + self._with_cluster_center = with_cluster_center |
| 67 | + self._with_voxel_center = with_voxel_center |
| 68 | + # Create PillarFeatureNet layers |
| 69 | + self.in_channels = in_channels |
| 70 | + feat_channels = [in_channels] + list(feat_channels) |
| 71 | + pfn_layers = [] |
| 72 | + for i in range(len(feat_channels) - 1): |
| 73 | + in_filters = feat_channels[i] |
| 74 | + out_filters = feat_channels[i + 1] |
| 75 | + if i < len(feat_channels) - 2: |
| 76 | + last_layer = False |
| 77 | + else: |
| 78 | + last_layer = True |
| 79 | + pfn_layers.append( |
| 80 | + PFNLayer( |
| 81 | + in_filters, |
| 82 | + out_filters, |
| 83 | + norm_cfg=norm_cfg, |
| 84 | + last_layer=last_layer, |
| 85 | + mode=mode)) |
| 86 | + self.pfn_layers = nn.ModuleList(pfn_layers) |
| 87 | + |
| 88 | + # Need pillar (voxel) size and x/y offset in order to calculate offset |
| 89 | + self.vx = voxel_size[0] |
| 90 | + self.vy = voxel_size[1] |
| 91 | + self.vz = voxel_size[2] |
| 92 | + self.x_offset = self.vx / 2 + point_cloud_range[0] |
| 93 | + self.y_offset = self.vy / 2 + point_cloud_range[1] |
| 94 | + self.z_offset = self.vz / 2 + point_cloud_range[2] |
| 95 | + self.point_cloud_range = point_cloud_range |
| 96 | + |
| 97 | + def forward(self, features: Tensor, num_points: Tensor, coors: Tensor, |
| 98 | + *args, **kwargs) -> Tensor: |
| 99 | + """Forward function. |
| 100 | +
|
| 101 | + Args: |
| 102 | + features (torch.Tensor): Point features or raw points in shape |
| 103 | + (N, M, C). |
| 104 | + num_points (torch.Tensor): Number of points in each pillar. |
| 105 | + coors (torch.Tensor): Coordinates of each voxel. |
| 106 | +
|
| 107 | + Returns: |
| 108 | + torch.Tensor: Features of pillars. |
| 109 | + """ |
| 110 | + features_ls = [features] |
| 111 | + # Find distance of x, y, and z from cluster center |
| 112 | + if self._with_cluster_center: |
| 113 | + points_mean = features[:, :, :3].sum( |
| 114 | + dim=1, keepdim=True) / num_points.type_as(features).view( |
| 115 | + -1, 1, 1) |
| 116 | + f_cluster = features[:, :, :3] - points_mean |
| 117 | + features_ls.append(f_cluster) |
| 118 | + |
| 119 | + # Find distance of x, y, and z from pillar center |
| 120 | + dtype = features.dtype |
| 121 | + if self._with_voxel_center: |
| 122 | + center_feature_size = 3 if self.use_voxel_center_z else 2 |
| 123 | + if not self.legacy: |
| 124 | + f_center = torch.zeros_like( |
| 125 | + features[:, :, :center_feature_size]) |
| 126 | + f_center[:, :, 0] = features[:, :, 0] - ( |
| 127 | + coors[:, 3].to(dtype).unsqueeze(1) * self.vx + |
| 128 | + self.x_offset) |
| 129 | + f_center[:, :, 1] = features[:, :, 1] - ( |
| 130 | + coors[:, 2].to(dtype).unsqueeze(1) * self.vy + |
| 131 | + self.y_offset) |
| 132 | + if self.use_voxel_center_z: |
| 133 | + f_center[:, :, 2] = features[:, :, 2] - ( |
| 134 | + coors[:, 1].to(dtype).unsqueeze(1) * self.vz + |
| 135 | + self.z_offset) |
| 136 | + else: |
| 137 | + f_center = features[:, :, :center_feature_size] |
| 138 | + f_center[:, :, 0] = f_center[:, :, 0] - ( |
| 139 | + coors[:, 3].type_as(features).unsqueeze(1) * self.vx + |
| 140 | + self.x_offset) |
| 141 | + f_center[:, :, 1] = f_center[:, :, 1] - ( |
| 142 | + coors[:, 2].type_as(features).unsqueeze(1) * self.vy + |
| 143 | + self.y_offset) |
| 144 | + if self.use_voxel_center_z: |
| 145 | + f_center[:, :, 2] = f_center[:, :, 2] - ( |
| 146 | + coors[:, 1].type_as(features).unsqueeze(1) * self.vz + |
| 147 | + self.z_offset) |
| 148 | + features_ls.append(f_center) |
| 149 | + |
| 150 | + if self._with_distance: |
| 151 | + points_dist = torch.norm(features[:, :, :3], 2, 2, keepdim=True) |
| 152 | + features_ls.append(points_dist) |
| 153 | + |
| 154 | + # Combine together feature decorations |
| 155 | + features = torch.cat(features_ls, dim=-1) |
| 156 | + # The feature decorations were calculated without regard to whether |
| 157 | + # pillar was empty. Need to ensure that |
| 158 | + # empty pillars remain set to zeros. |
| 159 | + voxel_count = features.shape[1] |
| 160 | + mask = get_paddings_indicator(num_points, voxel_count, axis=0) |
| 161 | + mask = torch.unsqueeze(mask, -1).type_as(features) |
| 162 | + features *= mask |
| 163 | + |
| 164 | + for pfn in self.pfn_layers: |
| 165 | + features = pfn(features, num_points) |
| 166 | + |
| 167 | + return features.squeeze(1) |
0 commit comments