|
import json
import os
from typing import Dict, List, Optional, Tuple, Type

import numpy as np
from nuscenes.eval.common.data_classes import EvalBox, EvalBoxes
from nuscenes.eval.common.loaders import load_prediction
from nuscenes.eval.common.utils import center_distance
from nuscenes.eval.detection.data_classes import DetectionBox
from nuscenes.eval.detection.evaluate import DetectionEval as _DetectionEval
| 12 | + |
def toEvalBoxes(nusc_boxes: Dict[str, List[Dict]], box_cls: Type[EvalBox] = DetectionBox) -> EvalBoxes:
    """Convert a plain dict of serialized boxes into an ``EvalBoxes`` container.

    ``nusc_boxes`` maps each sample token to a list of serialized box dicts::

        nusc_boxes = {
            "sample_token_1": [
                {
                    "sample_token": str,
                    "translation": List[float],  # (x, y, z)
                    "size": List[float],         # (width, length, height)
                    "rotation": List[float],     # (w, x, y, z) quaternion
                    "velocity": List[float],     # (vx, vy)
                    "detection_name": str,
                    "detection_score": float,
                    "attribute_name": str,
                },
                ...
            ],
            ...
        }

    Args:
        nusc_boxes (Dict[str, List[Dict]]): serialized boxes keyed by sample token.
        box_cls (Type[EvalBox], optional): box class used to deserialize each entry.
            Defaults to DetectionBox.

    Returns:
        EvalBoxes: the deserialized boxes, grouped by sample token.
    """
    # EvalBoxes.deserialize instantiates one box_cls per serialized dict.
    return EvalBoxes.deserialize(nusc_boxes, box_cls)
| 41 | + |
| 42 | + |
class DetectionConfig:
    """Data class that specifies the detection evaluation settings."""

    def __init__(
        self,
        class_names: List[str],
        class_range: Dict[str, int],
        dist_fcn: str,
        dist_ths: List[float],
        dist_th_tp: float,
        min_recall: float,
        min_precision: float,
        max_boxes_per_sample: float,
        mean_ap_weight: int,
    ):
        """
        :param class_names: Detection class names to evaluate.
        :param class_range: Maximum evaluation range (meters) per class.
        :param dist_fcn: Name of the matching distance function (e.g. "center_distance").
        :param dist_ths: Distance thresholds over which mAP is averaged.
        :param dist_th_tp: Distance threshold used for the TP metrics; must be in dist_ths.
        :param min_recall: Recall below which precision is clamped to 0 in AP.
        :param min_precision: Precision below which the PR curve is clamped in AP.
        :param max_boxes_per_sample: Maximum number of predicted boxes allowed per sample.
        :param mean_ap_weight: Weight of mAP relative to the TP metrics in the NDS.
        """
        # assert set(class_range.keys()) == set(DETECTION_NAMES), "Class count mismatch."
        assert dist_th_tp in dist_ths, "dist_th_tp must be in set of dist_ths."

        self.class_range = class_range
        self.dist_fcn = dist_fcn
        self.dist_ths = dist_ths
        self.dist_th_tp = dist_th_tp
        self.min_recall = min_recall
        self.min_precision = min_precision
        self.max_boxes_per_sample = max_boxes_per_sample
        self.mean_ap_weight = mean_ap_weight

        self.class_names = class_names

    def __eq__(self, other):
        """Compare field-by-field over the serialized keys; np.array_equal also
        handles list-valued fields such as dist_ths."""
        eq = True
        for key in self.serialize().keys():
            eq = eq and np.array_equal(getattr(self, key), getattr(other, key))
        return eq

    def serialize(self) -> dict:
        """Serialize instance into json-friendly format."""
        return {
            "class_names": self.class_names,
            "class_range": self.class_range,
            "dist_fcn": self.dist_fcn,
            "dist_ths": self.dist_ths,
            "dist_th_tp": self.dist_th_tp,
            "min_recall": self.min_recall,
            "min_precision": self.min_precision,
            "max_boxes_per_sample": self.max_boxes_per_sample,
            "mean_ap_weight": self.mean_ap_weight,
        }

    @classmethod
    def deserialize(cls, content: dict):
        """Initialize from serialized dictionary (inverse of serialize)."""
        return cls(
            content["class_names"],
            content["class_range"],
            content["dist_fcn"],
            content["dist_ths"],
            content["dist_th_tp"],
            content["min_recall"],
            content["min_precision"],
            content["max_boxes_per_sample"],
            content["mean_ap_weight"],
        )

    @property
    def dist_fcn_callable(self):
        """Return the distance function corresponding to the dist_fcn string.

        :raises ValueError: if dist_fcn names an unknown distance function.
        """
        if self.dist_fcn == "center_distance":
            return center_distance
        else:
            # ValueError (a subclass of Exception) is more precise than the
            # bare Exception previously raised here; callers catching
            # Exception are unaffected.
            raise ValueError(f"Error: Unknown distance function {self.dist_fcn}!")
| 115 | + |
| 116 | + |
class nuScenesDetectionEval(_DetectionEval):
    """
    This is the official nuScenes detection evaluation code.
    Results are written to the provided output_dir.
    nuScenes uses the following detection metrics:
    - Mean Average Precision (mAP): Uses center-distance as matching criterion; averaged over distance thresholds.
    - True Positive (TP) metrics: Average of translation, velocity, scale, orientation and attribute errors.
    - nuScenes Detection Score (NDS): The weighted sum of the above.
    Here is an overview of the functions in this method:
    - init: Loads GT annotations and predictions stored in JSON format and filters the boxes.
    - run: Performs evaluation and dumps the metric data to disk.
    - render: Renders various plots and dumps to disk.
    We assume that:
    - Every sample_token is given in the results, although there may be not predictions for that sample.
    Please see https://www.nuscenes.org/object-detection for more details.
    """

    def __init__(
        self,
        config: DetectionConfig,
        result_boxes: Dict,
        gt_boxes: Dict,
        meta: Dict,
        eval_set: str,
        output_dir: Optional[str] = None,
        verbose: bool = True,
    ):
        """
        Initialize a DetectionEval object.

        Unlike the upstream DetectionEval, this variant takes boxes as in-memory
        dicts (see toEvalBoxes) instead of loading them from JSON files.

        :param config: A DetectionConfig object.
        :param result_boxes: result bounding boxes, keyed by sample token.
        :param gt_boxes: ground-truth bounding boxes, keyed by sample token.
        :param meta: metadata dict stored on the instance (not interpreted here).
        :param eval_set: The dataset split to evaluate on, e.g. train, val or test.
        :param output_dir: Folder to save plots and results to.
        :param verbose: Whether to print to stdout.
        """
        self.cfg = config
        self.meta = meta
        self.eval_set = eval_set
        self.output_dir = output_dir
        self.verbose = verbose

        # Make dirs.
        # NOTE(review): output_dir defaults to None, but os.path.join(None, ...)
        # raises TypeError — it looks like a non-None output_dir is actually
        # required here; confirm and tighten the annotation or add a guard.
        self.plot_dir = os.path.join(self.output_dir, "plots")
        if not os.path.isdir(self.output_dir):
            os.makedirs(self.output_dir)
        if not os.path.isdir(self.plot_dir):
            os.makedirs(self.plot_dir)

        # Convert the plain serialized dicts into EvalBoxes containers.
        self.pred_boxes: EvalBoxes = toEvalBoxes(result_boxes)
        self.gt_boxes: EvalBoxes = toEvalBoxes(gt_boxes)

        # Predictions and ground truth must cover exactly the same samples.
        assert set(self.pred_boxes.sample_tokens) == set(
            self.gt_boxes.sample_tokens
        ), "Samples in split doesn't match samples in predictions."

        self.sample_tokens = self.gt_boxes.sample_tokens
0 commit comments