Skip to content

Commit 5bd5365

Browse files
committed
add TIERIV dataset reader and custom dataset nuscenes eval metric
Signed-off-by: Kaan Çolak <kaancolak95@gmail.com>
1 parent be583bc commit 5bd5365

File tree

12 files changed

+2938
-2
lines changed

12 files changed

+2938
-2
lines changed

configs/centerpoint/centerpoint_custom_test.py

+602
Large diffs are not rendered by default.

mmdet3d/datasets/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@
2424
RandomShiftScale, Resize3D, VoxelBasedPointSampler)
2525
from .utils import get_loading_pipeline
2626
from .waymo_dataset import WaymoDataset
27+
from .tier4_dataset import Tier4Dataset
2728

2829
__all__ = [
29-
'KittiDataset', 'CBGSDataset', 'NuScenesDataset', 'LyftDataset',
30+
'KittiDataset', 'CBGSDataset', 'NuScenesDataset', 'LyftDataset', 'Tier4Dataset',
3031
'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans',
3132
'PointShuffle', 'ObjectRangeFilter', 'PointsRangeFilter',
3233
'LoadPointsFromFile', 'S3DISSegDataset', 'S3DISDataset',

mmdet3d/datasets/tier4_dataset.py

+115
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from os import path as osp
3+
import os
4+
from typing import Callable, List, Union
5+
6+
import numpy as np
7+
8+
from mmdet3d.registry import DATASETS
9+
from mmdet3d.structures import LiDARInstance3DBoxes
10+
from mmdet3d.structures.bbox_3d.cam_box3d import CameraInstance3DBoxes
11+
# from .det3d_dataset import Det3DDataset
12+
from .nuscenes_dataset import NuScenesDataset
13+
14+
15+
@DATASETS.register_module()
16+
class Tier4Dataset(NuScenesDataset):
17+
METAINFO = {
18+
'classes': ('car', 'truck', 'bus', 'bicycle', 'pedestrian'),
19+
'version': 'v1.0-trainval',
20+
'palette': [
21+
(255, 158, 0), # Orange
22+
(255, 99, 71), # Tomato
23+
(255, 140, 0), # Darkorange
24+
(255, 127, 80), # Coral
25+
(233, 150, 70), # Darksalmon
26+
]
27+
}
28+
29+
def __init__(self,
30+
box_type_3d: str = 'LiDAR',
31+
load_type: str = 'frame_based',
32+
with_velocity: bool = True,
33+
use_valid_flag: bool = False,
34+
**kwargs,) -> None:
35+
36+
self.use_valid_flag = use_valid_flag
37+
self.with_velocity = with_velocity
38+
39+
# TODO: Redesign multi-view data process in the future
40+
assert load_type in ('frame_based', 'mv_image_based',
41+
'fov_image_based')
42+
self.load_type = load_type
43+
44+
assert box_type_3d.lower() in ('lidar', 'camera')
45+
super().__init__(**kwargs)
46+
47+
def parse_data_info(self, info: dict) -> dict:
48+
"""Process the raw data info.
49+
50+
Convert all relative path of needed modality data file to
51+
the absolute path. And process the `instances` field to
52+
`ann_info` in training stage.
53+
54+
Args:
55+
info (dict): Raw info dict.
56+
57+
Returns:
58+
dict: Has `ann_info` in training stage. And
59+
all path has been converted to absolute path.
60+
"""
61+
if self.load_type == 'mv_image_based':
62+
info = super().parse_data_info(info)
63+
else:
64+
if self.modality['use_lidar']:
65+
info['lidar_points']['lidar_path'] = \
66+
osp.join(
67+
self.data_prefix.get('pts', ''),
68+
info['lidar_points']['lidar_path'])
69+
70+
info['num_pts_feats'] = info['lidar_points']['num_pts_feats']
71+
info['lidar_path'] = info['lidar_points']['lidar_path']
72+
if 'lidar_sweeps' in info:
73+
for sweep in info['lidar_sweeps']:
74+
file_suffix_splitted = sweep['lidar_points']['lidar_path'].split(os.sep)
75+
file_suffix = os.sep.join(file_suffix_splitted[-4:])
76+
if 'samples' in sweep['lidar_points']['lidar_path']:
77+
sweep['lidar_points']['lidar_path'] = osp.join(
78+
self.data_prefix['pts'], file_suffix)
79+
else:
80+
sweep['lidar_points']['lidar_path'] = info['lidar_points']['lidar_path']
81+
82+
if self.modality['use_camera']:
83+
for cam_id, img_info in info['images'].items():
84+
if 'img_path' in img_info:
85+
if cam_id in self.data_prefix:
86+
cam_prefix = self.data_prefix[cam_id]
87+
else:
88+
cam_prefix = self.data_prefix.get('img', '')
89+
img_info['img_path'] = osp.join(cam_prefix,
90+
img_info['img_path'])
91+
if self.default_cam_key is not None:
92+
info['img_path'] = info['images'][
93+
self.default_cam_key]['img_path']
94+
if 'lidar2cam' in info['images'][self.default_cam_key]:
95+
info['lidar2cam'] = np.array(
96+
info['images'][self.default_cam_key]['lidar2cam'])
97+
if 'cam2img' in info['images'][self.default_cam_key]:
98+
info['cam2img'] = np.array(
99+
info['images'][self.default_cam_key]['cam2img'])
100+
if 'lidar2img' in info['images'][self.default_cam_key]:
101+
info['lidar2img'] = np.array(
102+
info['images'][self.default_cam_key]['lidar2img'])
103+
else:
104+
info['lidar2img'] = info['cam2img'] @ info['lidar2cam']
105+
106+
if not self.test_mode:
107+
# used in training
108+
info['ann_info'] = self.parse_ann_info(info)
109+
if self.test_mode and self.load_eval_anns:
110+
info['eval_ann_info'] = self.parse_ann_info(info)
111+
112+
return info
113+
114+
115+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from .eval import DetectionConfig, nuScenesDetectionEval
2+
from .utils import (
3+
class_mapping_kitti2nuscenes,
4+
format_nuscenes_metrics,
5+
format_nuscenes_metrics_table,
6+
transform_det_annos_to_nusc_annos,
7+
)
8+
9+
__all__ = [
10+
"DetectionConfig",
11+
"nuScenesDetectionEval",
12+
"class_mapping_kitti2nuscenes",
13+
"format_nuscenes_metrics_table",
14+
"format_nuscenes_metrics",
15+
"transform_det_annos_to_nusc_annos",
16+
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
import json
2+
import os
3+
from typing import Dict, List, Optional, Tuple
4+
5+
import numpy as np
6+
from nuscenes.eval.common.data_classes import EvalBox, EvalBoxes
7+
from nuscenes.eval.common.loaders import load_prediction
8+
from nuscenes.eval.common.utils import center_distance
9+
from nuscenes.eval.detection.data_classes import DetectionBox
10+
from nuscenes.eval.detection.evaluate import DetectionEval as _DetectionEval
11+
12+
13+
def toEvalBoxes(nusc_boxes: Dict[str, List[Dict]], box_cls: EvalBox = DetectionBox) -> EvalBoxes:
14+
"""
15+
16+
nusc_boxes = {
17+
"sample_token_1": [
18+
{
19+
"sample_token": str,
20+
"translation": List[float], (x, y, z)
21+
"size": List[float], (width, length, height)
22+
"rotation": List[float], (w, x, y, z)
23+
"velocity": List[float], (vx, vy)
24+
"detection_name": str,
25+
"detection_score": float,
26+
"attribute_name": str,
27+
},
28+
...
29+
],
30+
...
31+
}
32+
33+
Args:
34+
nusc_boxes (Dict[List[Dict]]): [description]
35+
box_cls (EvalBox, optional): [description]. Defaults to DetectionBox.
36+
37+
Returns:
38+
EvalBoxes: [description]
39+
"""
40+
return EvalBoxes.deserialize(nusc_boxes, box_cls)
41+
42+
43+
class DetectionConfig:
44+
"""Data class that specifies the detection evaluation settings."""
45+
46+
def __init__(
47+
self,
48+
class_names: List[str],
49+
class_range: Dict[str, int],
50+
dist_fcn: str,
51+
dist_ths: List[float],
52+
dist_th_tp: float,
53+
min_recall: float,
54+
min_precision: float,
55+
max_boxes_per_sample: float,
56+
mean_ap_weight: int,
57+
):
58+
59+
# assert set(class_range.keys()) == set(DETECTION_NAMES), "Class count mismatch."
60+
assert dist_th_tp in dist_ths, "dist_th_tp must be in set of dist_ths."
61+
62+
self.class_range = class_range
63+
self.dist_fcn = dist_fcn
64+
self.dist_ths = dist_ths
65+
self.dist_th_tp = dist_th_tp
66+
self.min_recall = min_recall
67+
self.min_precision = min_precision
68+
self.max_boxes_per_sample = max_boxes_per_sample
69+
self.mean_ap_weight = mean_ap_weight
70+
71+
self.class_names = class_names
72+
73+
def __eq__(self, other):
74+
eq = True
75+
for key in self.serialize().keys():
76+
eq = eq and np.array_equal(getattr(self, key), getattr(other, key))
77+
return eq
78+
79+
def serialize(self) -> dict:
80+
"""Serialize instance into json-friendly format."""
81+
return {
82+
"class_names": self.class_names,
83+
"class_range": self.class_range,
84+
"dist_fcn": self.dist_fcn,
85+
"dist_ths": self.dist_ths,
86+
"dist_th_tp": self.dist_th_tp,
87+
"min_recall": self.min_recall,
88+
"min_precision": self.min_precision,
89+
"max_boxes_per_sample": self.max_boxes_per_sample,
90+
"mean_ap_weight": self.mean_ap_weight,
91+
}
92+
93+
@classmethod
94+
def deserialize(cls, content: dict):
95+
"""Initialize from serialized dictionary."""
96+
return cls(
97+
content["class_names"],
98+
content["class_range"],
99+
content["dist_fcn"],
100+
content["dist_ths"],
101+
content["dist_th_tp"],
102+
content["min_recall"],
103+
content["min_precision"],
104+
content["max_boxes_per_sample"],
105+
content["mean_ap_weight"],
106+
)
107+
108+
@property
109+
def dist_fcn_callable(self):
110+
"""Return the distance function corresponding to the dist_fcn string."""
111+
if self.dist_fcn == "center_distance":
112+
return center_distance
113+
else:
114+
raise Exception("Error: Unknown distance function %s!" % self.dist_fcn)
115+
116+
117+
class nuScenesDetectionEval(_DetectionEval):
118+
"""
119+
This is the official nuScenes detection evaluation code.
120+
Results are written to the provided output_dir.
121+
nuScenes uses the following detection metrics:
122+
- Mean Average Precision (mAP): Uses center-distance as matching criterion; averaged over distance thresholds.
123+
- True Positive (TP) metrics: Average of translation, velocity, scale, orientation and attribute errors.
124+
- nuScenes Detection Score (NDS): The weighted sum of the above.
125+
Here is an overview of the functions in this method:
126+
- init: Loads GT annotations and predictions stored in JSON format and filters the boxes.
127+
- run: Performs evaluation and dumps the metric data to disk.
128+
- render: Renders various plots and dumps to disk.
129+
We assume that:
130+
- Every sample_token is given in the results, although there may be not predictions for that sample.
131+
Please see https://www.nuscenes.org/object-detection for more details.
132+
"""
133+
134+
def __init__(
135+
self,
136+
config: DetectionConfig,
137+
result_boxes: Dict,
138+
gt_boxes: Dict,
139+
meta: Dict,
140+
eval_set: str,
141+
output_dir: Optional[str] = None,
142+
verbose: bool = True,
143+
):
144+
"""
145+
Initialize a DetectionEval object.
146+
:param config: A DetectionConfig object.
147+
:param result_boxes: result bounding boxes.
148+
:param gt_boxes: ground-truth bounding boxes.
149+
:param eval_set: The dataset split to evaluate on, e.g. train, val or test.
150+
:param output_dir: Folder to save plots and results to.
151+
:param verbose: Whether to print to stdout.
152+
"""
153+
self.cfg = config
154+
self.meta = meta
155+
self.eval_set = eval_set
156+
self.output_dir = output_dir
157+
self.verbose = verbose
158+
159+
# Make dirs.
160+
self.plot_dir = os.path.join(self.output_dir, "plots")
161+
if not os.path.isdir(self.output_dir):
162+
os.makedirs(self.output_dir)
163+
if not os.path.isdir(self.plot_dir):
164+
os.makedirs(self.plot_dir)
165+
166+
self.pred_boxes: EvalBoxes = toEvalBoxes(result_boxes)
167+
self.gt_boxes: EvalBoxes = toEvalBoxes(gt_boxes)
168+
169+
assert set(self.pred_boxes.sample_tokens) == set(
170+
self.gt_boxes.sample_tokens
171+
), "Samples in split doesn't match samples in predictions."
172+
173+
self.sample_tokens = self.gt_boxes.sample_tokens

0 commit comments

Comments
 (0)