
Commit d167d59

feat: add test, docker file and readme
Signed-off-by: Kaan Colak <kaancolak95@gmail.com>
1 parent 90d15e9 commit d167d59

File tree: 4 files changed, +133 -12 lines

docker/Dockerfile (+4 -12)

@@ -1,5 +1,5 @@
-ARG PYTORCH="1.9.0"
-ARG CUDA="11.1"
+ARG PYTORCH="1.13.1"
+ARG CUDA="11.6"
 ARG CUDNN="8"
 
 FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
@@ -9,14 +9,6 @@ ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX" \
     CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" \
     FORCE_CUDA="1"
 
-# Avoid Public GPG key error
-# https://github.com/NVIDIA/nvidia-docker/issues/1631
-RUN rm /etc/apt/sources.list.d/cuda.list \
-    && rm /etc/apt/sources.list.d/nvidia-ml.list \
-    && apt-key del 7fa2af80 \
-    && apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub \
-    && apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
-
 # (Optional, use Mirror to speed up downloads)
 # RUN sed -i 's/http:\/\/archive.ubuntu.com\/ubuntu\//http:\/\/mirrors.aliyun.com\/ubuntu\//g' /etc/apt/sources.list && \
 # pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
@@ -29,11 +21,11 @@ RUN apt-get update \
 
 # Install MMEngine, MMCV and MMDetection
 RUN pip install openmim && \
-    mim install "mmengine" "mmcv>=2.0.0rc4" "mmdet>=3.0.0"
+    mim install "mmengine" "mmcv>=2.0.0rc4" "mmdet>=3.0.0rc5, <3.3.0"
 
 # Install MMDetection3D
 RUN conda clean --all \
-    && git clone https://github.com/open-mmlab/mmdetection3d.git -b dev-1.x /mmdetection3d \
+    && git clone https://github.com/autowarefoundation/mmdetection3d.git -b main /mmdetection3d \
    && cd /mmdetection3d \
    && pip install --no-cache-dir -e .
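
The mim pin above constrains mmdet to the half-open range [3.0.0rc5, 3.3.0). As a hedged sketch, one might verify inside the built container that the installed versions satisfy these pins; the check_pins helper below is illustrative and not part of this commit:

# check_pins.py: illustrative helper, not part of this commit.
from importlib.metadata import version

from packaging.specifiers import SpecifierSet
from packaging.version import Version


def check_pins() -> None:
    """Assert that the mim-installed packages satisfy the Dockerfile pins."""
    pins = {
        'mmcv': SpecifierSet('>=2.0.0rc4'),
        'mmdet': SpecifierSet('>=3.0.0rc5, <3.3.0'),
    }
    for pkg, spec in pins.items():
        installed = Version(version(pkg))
        # prereleases=True is required because the pins themselves use rc tags
        assert spec.contains(installed, prereleases=True), \
            f'{pkg} {installed} violates {spec}'


if __name__ == '__main__':
    check_pins()
    print('all version pins satisfied')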

README (new file, +11)

@@ -0,0 +1,11 @@
1+
## Introduction
2+
3+
The **[mmdetection3d](https://github.com/open-mmlab/mmdetection3d)** repository includes an additional voxel encoder
4+
feature for the CenterPoint 3D object detection model, known as voxel center z,
5+
not originally used in the **[main implementation](https://github.com/tianweiy/CenterPoint)**,
6+
Autoware maintains consistency with the input size of the original implementation. Consequently,
7+
to ensure integration with Autoware's lidar centerpoint package, we have forked the original repository and made
8+
the requisite code modifications.
9+
10+
To train custom CenterPoint models and convert them into ONNX format for deployment in Autoware, please refer to the instructions provided in the README.md file included with
11+
Autoware's **[lidar_centerpoint](https://autowarefoundation.github.io/autoware.universe/main/perception/lidar_centerpoint/)** package. These instructions will provide a step-by-step guide for training the CenterPoint model.
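
As a hedged sketch of the input-width arithmetic behind this change: with the default 4 point channels, 3 cluster-center offsets, and 2 voxel-center offsets, the pillar encoder sees 9 input features, and enabling voxel center z would add a tenth. The helper below is hypothetical; its name and the exactly-one-channel-for-z assumption are illustrative, not taken from the fork:

# pfn_in_features: hypothetical helper, not from the fork; it only
# illustrates how each optional feature widens the encoder input.
def pfn_in_features(in_channels: int = 4,
                    with_cluster_center: bool = True,
                    use_voxel_center_z: bool = False,
                    with_distance: bool = False) -> int:
    features = in_channels      # x, y, z, intensity
    if with_cluster_center:
        features += 3           # offsets to the pillar's point-cluster mean
    features += 2               # offsets to the voxel center in x and y
    if use_voxel_center_z:
        features += 1           # the z offset this fork makes optional
    if with_distance:
        features += 1           # Euclidean distance to the sensor origin
    return features


assert pfn_in_features() == 9                          # Autoware-compatible input size
assert pfn_in_features(use_voxel_center_z=True) == 10  # upstream mmdetection3d behaviour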

Test for PillarFeatureNetAutoware (new file, +41)

@@ -0,0 +1,41 @@
1+
import pytest
2+
import torch
3+
4+
from mmdet3d.registry import MODELS
5+
from projects.AutowareCenterPoint.centerpoint.pillar_encoder_autoware import \
6+
PillarFeatureNetAutoware # noqa: F401
7+
8+
9+
def test_pillar_feature_net_autoware():
10+
11+
use_voxel_center_z = False
12+
if not torch.cuda.is_available():
13+
pytest.skip('test requires GPU and torch+cuda')
14+
pillar_feature_net_autoware_cfg = dict(
15+
type='PillarFeatureNetAutoware',
16+
in_channels=4,
17+
feat_channels=[64],
18+
voxel_size=(0.2, 0.2, 8),
19+
point_cloud_range=(-51.2, -51.2, -5.0, 51.2, 51.2, 3.0),
20+
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
21+
use_voxel_center_z=use_voxel_center_z,
22+
with_distance=False,
23+
)
24+
pillar_feature_net_autoware = MODELS.build(pillar_feature_net_autoware_cfg)
25+
26+
features = torch.rand([97297, 20, 4])
27+
num_voxels = torch.randint(1, 100, [97297])
28+
coors = torch.randint(0, 100, [97297, 4])
29+
30+
features = pillar_feature_net_autoware(features, num_voxels, coors)
31+
32+
if not use_voxel_center_z:
33+
assert pillar_feature_net_autoware.pfn_layers[
34+
0].linear.in_features == 9
35+
else:
36+
assert pillar_feature_net_autoware.pfn_layers[
37+
0].linear.in_features == 9
38+
39+
assert pillar_feature_net_autoware.pfn_layers[0].linear.out_features == 64
40+
41+
assert features.shape == torch.Size([97297, 64])
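
As a hedged follow-up (not part of the commit), the same encoder config could be rebuilt with the flag flipped to check both expected input widths; this assumes module construction does not require CUDA and that enabling voxel center z adds exactly one input feature:

from mmdet3d.registry import MODELS
from projects.AutowareCenterPoint.centerpoint.pillar_encoder_autoware import \
    PillarFeatureNetAutoware  # noqa: F401

# Expected widths assume the fork adds exactly one input channel for z.
for use_z, expected in ((False, 9), (True, 10)):
    net = MODELS.build(
        dict(
            type='PillarFeatureNetAutoware',
            in_channels=4,
            feat_channels=[64],
            voxel_size=(0.2, 0.2, 8),
            point_cloud_range=(-51.2, -51.2, -5.0, 51.2, 51.2, 3.0),
            norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
            use_voxel_center_z=use_z,
            with_distance=False))
    assert net.pfn_layers[0].linear.in_features == expected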

Test for T4Dataset (new file, +77)

@@ -0,0 +1,77 @@
1+
import numpy as np
2+
from mmcv.transforms.base import BaseTransform
3+
from mmengine.registry import TRANSFORMS
4+
from mmengine.structures import InstanceData
5+
6+
from mmdet3d.structures import Det3DDataSample, LiDARInstance3DBoxes
7+
from projects.AutowareCenterPoint.datasets.tier4_dataset import T4Dataset
8+
9+
10+
def _generate_t4_dataset_config():
11+
data_root = 'data/sample_dataset/'
12+
ann_file = 'T4Dataset_infos_train.pkl'
13+
classes = ['car', 'truck', 'bus', 'bicycle', 'pedestrian']
14+
15+
if 'Identity' not in TRANSFORMS:
16+
17+
@TRANSFORMS.register_module()
18+
class Identity(BaseTransform):
19+
20+
def transform(self, info):
21+
packed_input = dict(data_samples=Det3DDataSample())
22+
if 'ann_info' in info:
23+
packed_input[
24+
'data_samples'].gt_instances_3d = InstanceData()
25+
packed_input[
26+
'data_samples'].gt_instances_3d.labels_3d = info[
27+
'ann_info']['gt_labels_3d']
28+
return packed_input
29+
30+
pipeline = [
31+
dict(type='Identity'),
32+
]
33+
modality = dict(use_lidar=True, use_camera=True)
34+
data_prefix = dict(
35+
pts='samples/LIDAR_TOP',
36+
img='samples/CAM_BACK_LEFT',
37+
sweeps='sweeps/LIDAR_TOP')
38+
return data_root, ann_file, classes, data_prefix, pipeline, modality
39+
40+
41+
def test_getitem():
42+
np.random.seed(0)
43+
data_root, ann_file, classes, data_prefix, pipeline, modality = \
44+
_generate_t4_dataset_config()
45+
46+
t4_dataset = T4Dataset(
47+
data_root=data_root,
48+
ann_file=ann_file,
49+
data_prefix=data_prefix,
50+
pipeline=pipeline,
51+
metainfo=dict(classes=classes),
52+
modality=modality)
53+
54+
t4_dataset.prepare_data(0)
55+
input_dict = t4_dataset.get_data_info(0)
56+
# assert the the path should contains data_prefix and data_root
57+
assert data_prefix['pts'] in input_dict['lidar_points']['lidar_path']
58+
assert data_root in input_dict['lidar_points']['lidar_path']
59+
60+
for cam_id, img_info in input_dict['images'].items():
61+
if 'img_path' in img_info:
62+
assert data_prefix['img'] in img_info['img_path']
63+
assert data_root in img_info['img_path']
64+
65+
ann_info = t4_dataset.parse_ann_info(input_dict)
66+
67+
# assert the keys in ann_info and the type
68+
assert 'gt_labels_3d' in ann_info
69+
assert ann_info['gt_labels_3d'].dtype == np.int64
70+
assert len(ann_info['gt_labels_3d']) == 70
71+
72+
assert 'gt_bboxes_3d' in ann_info
73+
assert isinstance(ann_info['gt_bboxes_3d'], LiDARInstance3DBoxes)
74+
75+
assert len(t4_dataset.metainfo['classes']) == 5
76+
assert input_dict['token'] == '5f73a4f0dd74434260bf72821b24c8d4'
77+
assert input_dict['timestamp'] == 1697190328.324525
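
The test builds T4Dataset directly; as a hedged sketch (not from this commit), the same dataset could be wired into an mmengine-style dataloader config, reusing the values from the test. Batch size, worker count, and sampler here are placeholders:

# Hypothetical dataloader config reusing the test's dataset settings;
# batch_size, num_workers and the sampler are placeholders, not from the commit.
train_dataloader = dict(
    batch_size=2,
    num_workers=2,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type='T4Dataset',
        data_root='data/sample_dataset/',
        ann_file='T4Dataset_infos_train.pkl',
        data_prefix=dict(
            pts='samples/LIDAR_TOP',
            img='samples/CAM_BACK_LEFT',
            sweeps='sweeps/LIDAR_TOP'),
        pipeline=[dict(type='Identity')],
        metainfo=dict(
            classes=['car', 'truck', 'bus', 'bicycle', 'pedestrian']),
        modality=dict(use_lidar=True, use_camera=True)))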
