Skip to content

Commit b89c542

Browse files
authored
Add Autoware compatibility (#1)
Signed-off-by: Kaan Çolak <kaancolak95@gmail.com>
1 parent 5c0613b commit b89c542

21 files changed

+4421
-12
lines changed

docker/Dockerfile

+4-12
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
ARG PYTORCH="1.9.0"
2-
ARG CUDA="11.1"
1+
ARG PYTORCH="1.13.1"
2+
ARG CUDA="11.6"
33
ARG CUDNN="8"
44

55
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
@@ -9,14 +9,6 @@ ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX" \
99
CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" \
1010
FORCE_CUDA="1"
1111

12-
# Avoid Public GPG key error
13-
# https://github.com/NVIDIA/nvidia-docker/issues/1631
14-
RUN rm /etc/apt/sources.list.d/cuda.list \
15-
&& rm /etc/apt/sources.list.d/nvidia-ml.list \
16-
&& apt-key del 7fa2af80 \
17-
&& apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub \
18-
&& apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
19-
2012
# (Optional, use Mirror to speed up downloads)
2113
# RUN sed -i 's/http:\/\/archive.ubuntu.com\/ubuntu\//http:\/\/mirrors.aliyun.com\/ubuntu\//g' /etc/apt/sources.list && \
2214
# pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
@@ -29,11 +21,11 @@ RUN apt-get update \
2921

3022
# Install MMEngine, MMCV and MMDetection
3123
RUN pip install openmim && \
32-
mim install "mmengine" "mmcv>=2.0.0rc4" "mmdet>=3.0.0"
24+
mim install "mmengine" "mmcv>=2.0.0rc4" "mmdet>=3.0.0rc5, <3.3.0"
3325

3426
# Install MMDetection3D
3527
RUN conda clean --all \
36-
&& git clone https://github.com/open-mmlab/mmdetection3d.git -b dev-1.x /mmdetection3d \
28+
&& git clone https://github.com/autowarefoundation/mmdetection3d.git -b main /mmdetection3d \
3729
&& cd /mmdetection3d \
3830
&& pip install --no-cache-dir -e .
3931

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
## Introduction
2+
3+
The **[mmdetection3d](https://github.com/open-mmlab/mmdetection3d)** repository includes an additional voxel encoder
4+
feature for the CenterPoint 3D object detection model, known as voxel center z,
5+
not originally used in the **[main implementation](https://github.com/tianweiy/CenterPoint)**,
6+
Autoware maintains consistency with the input size of the original implementation. Consequently,
7+
to ensure integration with Autoware's lidar centerpoint package, we have forked the original repository and made
8+
the requisite code modifications.
9+
10+
To train custom CenterPoint models and convert them into ONNX format for deployment in Autoware, please refer to the instructions provided in the README.md file included with
11+
Autoware's **[lidar_centerpoint](https://autowarefoundation.github.io/autoware.universe/main/perception/lidar_centerpoint/)** package. These instructions will provide a step-by-step guide for training the CenterPoint model.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .pillar_encoder_autoware import PillarFeatureNetAutoware
2+
3+
__all__ = ['PillarFeatureNetAutoware']
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
from typing import Optional, Tuple
2+
3+
import torch
4+
from torch import Tensor, nn
5+
6+
from mmdet3d.models.voxel_encoders.utils import (PFNLayer,
7+
get_paddings_indicator)
8+
from mmdet3d.registry import MODELS
9+
10+
11+
@MODELS.register_module()
12+
class PillarFeatureNetAutoware(nn.Module):
13+
"""Pillar Feature Net.
14+
15+
The network prepares the pillar features and performs forward pass
16+
through PFNLayers.
17+
18+
Args:
19+
in_channels (int, optional): Number of input features,
20+
either x, y, z or x, y, z, r. Defaults to 4.
21+
feat_channels (tuple, optional): Number of features in each of the
22+
N PFNLayers. Defaults to (64, ).
23+
with_distance (bool, optional): Whether to include Euclidean distance
24+
to points. Defaults to False.
25+
with_cluster_center (bool, optional): [description]. Defaults to True.
26+
with_voxel_center (bool, optional): [description]. Defaults to True.
27+
voxel_size (tuple[float], optional): Size of voxels, only utilize x
28+
and y size. Defaults to (0.2, 0.2, 4).
29+
point_cloud_range (tuple[float], optional): Point cloud range, only
30+
utilizes x and y min. Defaults to (0, -40, -3, 70.4, 40, 1).
31+
norm_cfg ([type], optional): [description].
32+
Defaults to dict(type='BN1d', eps=1e-3, momentum=0.01).
33+
mode (str, optional): The mode to gather point features. Options are
34+
'max' or 'avg'. Defaults to 'max'.
35+
legacy (bool, optional): Whether to use the new behavior or
36+
the original behavior. Defaults to True.
37+
"""
38+
39+
def __init__(
40+
self,
41+
in_channels: Optional[int] = 4,
42+
feat_channels: Optional[tuple] = (64, ),
43+
with_distance: Optional[bool] = False,
44+
with_cluster_center: Optional[bool] = True,
45+
with_voxel_center: Optional[bool] = True,
46+
voxel_size: Optional[Tuple[float]] = (0.2, 0.2, 4),
47+
point_cloud_range: Optional[Tuple[float]] = (0, -40, -3, 70.4, 40, 1),
48+
norm_cfg: Optional[dict] = dict(type='BN1d', eps=1e-3, momentum=0.01),
49+
mode: Optional[str] = 'max',
50+
legacy: Optional[bool] = True,
51+
use_voxel_center_z: Optional[bool] = True,
52+
):
53+
super(PillarFeatureNetAutoware, self).__init__()
54+
assert len(feat_channels) > 0
55+
self.legacy = legacy
56+
self.use_voxel_center_z = use_voxel_center_z
57+
if with_cluster_center:
58+
in_channels += 3
59+
if with_voxel_center:
60+
in_channels += 2
61+
if self.use_voxel_center_z:
62+
in_channels += 1
63+
if with_distance:
64+
in_channels += 1
65+
self._with_distance = with_distance
66+
self._with_cluster_center = with_cluster_center
67+
self._with_voxel_center = with_voxel_center
68+
# Create PillarFeatureNet layers
69+
self.in_channels = in_channels
70+
feat_channels = [in_channels] + list(feat_channels)
71+
pfn_layers = []
72+
for i in range(len(feat_channels) - 1):
73+
in_filters = feat_channels[i]
74+
out_filters = feat_channels[i + 1]
75+
if i < len(feat_channels) - 2:
76+
last_layer = False
77+
else:
78+
last_layer = True
79+
pfn_layers.append(
80+
PFNLayer(
81+
in_filters,
82+
out_filters,
83+
norm_cfg=norm_cfg,
84+
last_layer=last_layer,
85+
mode=mode))
86+
self.pfn_layers = nn.ModuleList(pfn_layers)
87+
88+
# Need pillar (voxel) size and x/y offset in order to calculate offset
89+
self.vx = voxel_size[0]
90+
self.vy = voxel_size[1]
91+
self.vz = voxel_size[2]
92+
self.x_offset = self.vx / 2 + point_cloud_range[0]
93+
self.y_offset = self.vy / 2 + point_cloud_range[1]
94+
self.z_offset = self.vz / 2 + point_cloud_range[2]
95+
self.point_cloud_range = point_cloud_range
96+
97+
def forward(self, features: Tensor, num_points: Tensor, coors: Tensor,
98+
*args, **kwargs) -> Tensor:
99+
"""Forward function.
100+
101+
Args:
102+
features (torch.Tensor): Point features or raw points in shape
103+
(N, M, C).
104+
num_points (torch.Tensor): Number of points in each pillar.
105+
coors (torch.Tensor): Coordinates of each voxel.
106+
107+
Returns:
108+
torch.Tensor: Features of pillars.
109+
"""
110+
features_ls = [features]
111+
# Find distance of x, y, and z from cluster center
112+
if self._with_cluster_center:
113+
points_mean = features[:, :, :3].sum(
114+
dim=1, keepdim=True) / num_points.type_as(features).view(
115+
-1, 1, 1)
116+
f_cluster = features[:, :, :3] - points_mean
117+
features_ls.append(f_cluster)
118+
119+
# Find distance of x, y, and z from pillar center
120+
dtype = features.dtype
121+
if self._with_voxel_center:
122+
center_feature_size = 3 if self.use_voxel_center_z else 2
123+
if not self.legacy:
124+
f_center = torch.zeros_like(
125+
features[:, :, :center_feature_size])
126+
f_center[:, :, 0] = features[:, :, 0] - (
127+
coors[:, 3].to(dtype).unsqueeze(1) * self.vx +
128+
self.x_offset)
129+
f_center[:, :, 1] = features[:, :, 1] - (
130+
coors[:, 2].to(dtype).unsqueeze(1) * self.vy +
131+
self.y_offset)
132+
if self.use_voxel_center_z:
133+
f_center[:, :, 2] = features[:, :, 2] - (
134+
coors[:, 1].to(dtype).unsqueeze(1) * self.vz +
135+
self.z_offset)
136+
else:
137+
f_center = features[:, :, :center_feature_size]
138+
f_center[:, :, 0] = f_center[:, :, 0] - (
139+
coors[:, 3].type_as(features).unsqueeze(1) * self.vx +
140+
self.x_offset)
141+
f_center[:, :, 1] = f_center[:, :, 1] - (
142+
coors[:, 2].type_as(features).unsqueeze(1) * self.vy +
143+
self.y_offset)
144+
if self.use_voxel_center_z:
145+
f_center[:, :, 2] = f_center[:, :, 2] - (
146+
coors[:, 1].type_as(features).unsqueeze(1) * self.vz +
147+
self.z_offset)
148+
features_ls.append(f_center)
149+
150+
if self._with_distance:
151+
points_dist = torch.norm(features[:, :, :3], 2, 2, keepdim=True)
152+
features_ls.append(points_dist)
153+
154+
# Combine together feature decorations
155+
features = torch.cat(features_ls, dim=-1)
156+
# The feature decorations were calculated without regard to whether
157+
# pillar was empty. Need to ensure that
158+
# empty pillars remain set to zeros.
159+
voxel_count = features.shape[1]
160+
mask = get_paddings_indicator(num_points, voxel_count, axis=0)
161+
mask = torch.unsqueeze(mask, -1).type_as(features)
162+
features *= mask
163+
164+
for pfn in self.pfn_layers:
165+
features = pfn(features, num_points)
166+
167+
return features.squeeze(1)

0 commit comments

Comments
 (0)