|
| 1 | +import numpy as np |
| 2 | +from mmcv.transforms.base import BaseTransform |
| 3 | +from mmengine.registry import TRANSFORMS |
| 4 | +from mmengine.structures import InstanceData |
| 5 | + |
| 6 | +from mmdet3d.datasets import NuScenesDataset |
| 7 | +from mmdet3d.structures import Det3DDataSample, LiDARInstance3DBoxes |
| 8 | + |
| 9 | +from projects.AutowareCenterPoint.datasets.tier4_dataset import T4Dataset |
| 10 | + |
| 11 | +def _generate_t4_dataset_config(): |
| 12 | + data_root = 'data/sample_dataset/' |
| 13 | + ann_file = 'T4Dataset_infos_train.pkl' |
| 14 | + classes = [ |
| 15 | + 'car', 'truck', 'bus', 'bicycle', 'pedestrian' |
| 16 | + ] |
| 17 | + |
| 18 | + if 'Identity' not in TRANSFORMS: |
| 19 | + |
| 20 | + @TRANSFORMS.register_module() |
| 21 | + class Identity(BaseTransform): |
| 22 | + |
| 23 | + def transform(self, info): |
| 24 | + packed_input = dict(data_samples=Det3DDataSample()) |
| 25 | + if 'ann_info' in info: |
| 26 | + packed_input[ |
| 27 | + 'data_samples'].gt_instances_3d = InstanceData() |
| 28 | + packed_input[ |
| 29 | + 'data_samples'].gt_instances_3d.labels_3d = info[ |
| 30 | + 'ann_info']['gt_labels_3d'] |
| 31 | + return packed_input |
| 32 | + |
| 33 | + pipeline = [ |
| 34 | + dict(type='Identity'), |
| 35 | + ] |
| 36 | + modality = dict(use_lidar=True, use_camera=True) |
| 37 | + data_prefix = dict( |
| 38 | + pts='samples/LIDAR_TOP', |
| 39 | + img='samples/CAM_BACK_LEFT', |
| 40 | + sweeps='sweeps/LIDAR_TOP') |
| 41 | + return data_root, ann_file, classes, data_prefix, pipeline, modality |
| 42 | + |
| 43 | + |
| 44 | +def test_getitem(): |
| 45 | + np.random.seed(0) |
| 46 | + data_root, ann_file, classes, data_prefix, pipeline, modality = \ |
| 47 | + _generate_t4_dataset_config() |
| 48 | + |
| 49 | + t4_dataset = T4Dataset( |
| 50 | + data_root=data_root, |
| 51 | + ann_file=ann_file, |
| 52 | + data_prefix=data_prefix, |
| 53 | + pipeline=pipeline, |
| 54 | + metainfo=dict(classes=classes), |
| 55 | + modality=modality) |
| 56 | + |
| 57 | + t4_dataset.prepare_data(0) |
| 58 | + input_dict = t4_dataset.get_data_info(0) |
| 59 | + # assert the the path should contains data_prefix and data_root |
| 60 | + assert data_prefix['pts'] in input_dict['lidar_points']['lidar_path'] |
| 61 | + assert data_root in input_dict['lidar_points']['lidar_path'] |
| 62 | + |
| 63 | + for cam_id, img_info in input_dict['images'].items(): |
| 64 | + if 'img_path' in img_info: |
| 65 | + assert data_prefix['img'] in img_info['img_path'] |
| 66 | + assert data_root in img_info['img_path'] |
| 67 | + |
| 68 | + ann_info = t4_dataset.parse_ann_info(input_dict) |
| 69 | + |
| 70 | + # assert the keys in ann_info and the type |
| 71 | + assert 'gt_labels_3d' in ann_info |
| 72 | + assert ann_info['gt_labels_3d'].dtype == np.int64 |
| 73 | + assert len(ann_info['gt_labels_3d']) == 70 |
| 74 | + |
| 75 | + assert 'gt_bboxes_3d' in ann_info |
| 76 | + assert isinstance(ann_info['gt_bboxes_3d'], LiDARInstance3DBoxes) |
| 77 | + |
| 78 | + assert len(t4_dataset.metainfo['classes']) == 5 |
| 79 | + assert input_dict['token'] == '5f73a4f0dd74434260bf72821b24c8d4' |
| 80 | + assert input_dict['timestamp'] == 1697190328.324525 |
0 commit comments