Skip to content

Commit 9b7275f

Browse files
committed
feat: add unit test
Signed-off-by: Kaan Çolak <kaancolak95@gmail.com>
1 parent b80934c commit 9b7275f

File tree

2 files changed

+118
-0
lines changed

2 files changed

+118
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import pytest
2+
import torch
3+
import logging
4+
5+
from mmdet3d.registry import MODELS
6+
7+
import projects.AutowareCenterPoint.centerpoint.pillar_encoder_autoware
8+
def test_pillar_feature_net_autoware():
9+
10+
use_voxel_center_z = False
11+
if not torch.cuda.is_available():
12+
pytest.skip('test requires GPU and torch+cuda')
13+
pillar_feature_net_autoware_cfg = dict(
14+
type='PillarFeatureNetAutoware',
15+
in_channels=4,
16+
feat_channels=[64],
17+
voxel_size=(0.2, 0.2, 8),
18+
point_cloud_range=(-51.2, -51.2, -5.0, 51.2, 51.2, 3.0),
19+
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
20+
use_voxel_center_z=use_voxel_center_z,
21+
with_distance=False,
22+
)
23+
pillar_feature_net_autoware = MODELS.build(pillar_feature_net_autoware_cfg)
24+
25+
features = torch.rand([97297, 20, 4])
26+
num_voxels = torch.randint(1, 100, [97297])
27+
coors = torch.randint(0, 100, [97297, 4])
28+
29+
features = pillar_feature_net_autoware(features, num_voxels, coors)
30+
31+
if not use_voxel_center_z:
32+
assert pillar_feature_net_autoware.pfn_layers[0].linear.in_features == 9
33+
else:
34+
assert pillar_feature_net_autoware.pfn_layers[0].linear.in_features == 9
35+
36+
assert pillar_feature_net_autoware.pfn_layers[0].linear.out_features == 64
37+
38+
assert features.shape == torch.Size([97297, 64])
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import numpy as np
2+
from mmcv.transforms.base import BaseTransform
3+
from mmengine.registry import TRANSFORMS
4+
from mmengine.structures import InstanceData
5+
6+
from mmdet3d.datasets import NuScenesDataset
7+
from mmdet3d.structures import Det3DDataSample, LiDARInstance3DBoxes
8+
9+
from projects.AutowareCenterPoint.datasets.tier4_dataset import T4Dataset
10+
11+
def _generate_t4_dataset_config():
12+
data_root = 'data/sample_dataset/'
13+
ann_file = 'T4Dataset_infos_train.pkl'
14+
classes = [
15+
'car', 'truck', 'bus', 'bicycle', 'pedestrian'
16+
]
17+
18+
if 'Identity' not in TRANSFORMS:
19+
20+
@TRANSFORMS.register_module()
21+
class Identity(BaseTransform):
22+
23+
def transform(self, info):
24+
packed_input = dict(data_samples=Det3DDataSample())
25+
if 'ann_info' in info:
26+
packed_input[
27+
'data_samples'].gt_instances_3d = InstanceData()
28+
packed_input[
29+
'data_samples'].gt_instances_3d.labels_3d = info[
30+
'ann_info']['gt_labels_3d']
31+
return packed_input
32+
33+
pipeline = [
34+
dict(type='Identity'),
35+
]
36+
modality = dict(use_lidar=True, use_camera=True)
37+
data_prefix = dict(
38+
pts='samples/LIDAR_TOP',
39+
img='samples/CAM_BACK_LEFT',
40+
sweeps='sweeps/LIDAR_TOP')
41+
return data_root, ann_file, classes, data_prefix, pipeline, modality
42+
43+
44+
def test_getitem():
45+
np.random.seed(0)
46+
data_root, ann_file, classes, data_prefix, pipeline, modality = \
47+
_generate_t4_dataset_config()
48+
49+
t4_dataset = T4Dataset(
50+
data_root=data_root,
51+
ann_file=ann_file,
52+
data_prefix=data_prefix,
53+
pipeline=pipeline,
54+
metainfo=dict(classes=classes),
55+
modality=modality)
56+
57+
t4_dataset.prepare_data(0)
58+
input_dict = t4_dataset.get_data_info(0)
59+
# assert the the path should contains data_prefix and data_root
60+
assert data_prefix['pts'] in input_dict['lidar_points']['lidar_path']
61+
assert data_root in input_dict['lidar_points']['lidar_path']
62+
63+
for cam_id, img_info in input_dict['images'].items():
64+
if 'img_path' in img_info:
65+
assert data_prefix['img'] in img_info['img_path']
66+
assert data_root in img_info['img_path']
67+
68+
ann_info = t4_dataset.parse_ann_info(input_dict)
69+
70+
# assert the keys in ann_info and the type
71+
assert 'gt_labels_3d' in ann_info
72+
assert ann_info['gt_labels_3d'].dtype == np.int64
73+
assert len(ann_info['gt_labels_3d']) == 70
74+
75+
assert 'gt_bboxes_3d' in ann_info
76+
assert isinstance(ann_info['gt_bboxes_3d'], LiDARInstance3DBoxes)
77+
78+
assert len(t4_dataset.metainfo['classes']) == 5
79+
assert input_dict['token'] == '5f73a4f0dd74434260bf72821b24c8d4'
80+
assert input_dict['timestamp'] == 1697190328.324525

0 commit comments

Comments
 (0)