13
13
# limitations under the License.
14
14
15
15
from dataclasses import dataclass
16
+ import logging
17
+ from pathlib import Path
18
+ import sys
16
19
from typing import Literal
17
20
21
+ from autoware_perception_msgs .msg import TrafficSignalElement
18
22
from perception_eval .evaluation import PerceptionFrameResult
19
23
from pydantic import BaseModel
24
+ from pydantic import field_validator
25
+ import simplejson as json
20
26
21
27
from driving_log_replayer .criteria import PerceptionCriteria
28
+ from driving_log_replayer .perception_eval_conversions import summarize_pass_fail_result
22
29
from driving_log_replayer .result import EvaluationItem
23
30
from driving_log_replayer .result import ResultBase
24
31
from driving_log_replayer .scenario import number
25
32
from driving_log_replayer .scenario import Scenario
26
33
34
+ TRAFFIC_LIGHT_LABEL_MAPPINGS : list [tuple [set , str ]] = [
35
+ ({"green" }, "green" ),
36
+ ({"green" , "straight" }, "green_straight" ),
37
+ ({"green" , "left" }, "green_left" ),
38
+ ({"green" , "right" }, "green_right" ),
39
+ ({"yellow" }, "yellow" ),
40
+ ({"yellow" , "straight" }, "yellow_straight" ),
41
+ ({"yellow" , "left" }, "yellow_left" ),
42
+ ({"yellow" , "right" }, "yellow_right" ),
43
+ ({"yellow" , "straight" , "left" }, "yellow_straight_left" ),
44
+ ({"yellow" , "straight" , "right" }, "yellow_straight_right" ),
45
+ ({"red" }, "red" ),
46
+ ({"red" , "straight" }, "red_straight" ),
47
+ ({"red" , "left" }, "red_left" ),
48
+ ({"red" , "right" }, "red_right" ),
49
+ ({"red" , "straight" , "left" }, "red_straight_left" ),
50
+ ({"red" , "straight" , "right" }, "red_straight_right" ),
51
+ ({"red" , "straight" , "left" , "right" }, "red_straight_left_right" ),
52
+ ({"red" , "right" , "diagonal" }, "red_rightdiagonal" ),
53
+ ({"red" , "left" , "diagonal" }, "red_leftdiagonal" ),
54
+ ]
27
55
28
- class Conditions (BaseModel ):
56
+
57
+ def get_traffic_light_label_str (elements : list [TrafficSignalElement ]) -> str : # noqa
58
+ label_infos = []
59
+ for element in elements :
60
+ if element .shape == TrafficSignalElement .CIRCLE :
61
+ if element .color == TrafficSignalElement .RED :
62
+ label_infos .append ("red" )
63
+ elif element .color == TrafficSignalElement .AMBER :
64
+ label_infos .append ("yellow" )
65
+ elif element .color == TrafficSignalElement .GREEN :
66
+ label_infos .append ("green" )
67
+ continue
68
+
69
+ if element .shape == TrafficSignalElement .UP_ARROW :
70
+ label_infos .append ("straight" )
71
+ elif element .shape == TrafficSignalElement .LEFT_ARROW :
72
+ label_infos .append ("left" )
73
+ elif element .shape == TrafficSignalElement .RIGHT_ARROW :
74
+ label_infos .append ("right" )
75
+ elif element .shape in (
76
+ TrafficSignalElement .UP_LEFT_ARROW ,
77
+ TrafficSignalElement .DOWN_LEFT_ARROW ,
78
+ ):
79
+ label_infos .append ("left" )
80
+ label_infos .append ("diagonal" )
81
+ elif element .shape in (
82
+ TrafficSignalElement .UP_RIGHT_ARROW ,
83
+ TrafficSignalElement .DOWN_RIGHT_ARROW ,
84
+ ):
85
+ label_infos .append ("right" )
86
+ label_infos .append ("diagonal" )
87
+
88
+ label_infos = set (label_infos )
89
+
90
+ for info_set , label in TRAFFIC_LIGHT_LABEL_MAPPINGS :
91
+ if label_infos == info_set :
92
+ return label
93
+
94
+ return "unknown"
95
+
96
+
97
+ def get_most_probable_element (
98
+ elements : list [TrafficSignalElement ],
99
+ ) -> TrafficSignalElement :
100
+ index : int = elements .index (max (elements , key = lambda x : x .confidence ))
101
+ return elements [index ]
102
+
103
+
104
+ class Filter (BaseModel ):
105
+ Distance : tuple [float , float ] | None = None
106
+ # add filter condition here
107
+
108
+ @field_validator ("Distance" , mode = "before" )
109
+ @classmethod
110
+ def validate_distance_range (cls , v : str | None ) -> tuple [number , number ] | None :
111
+ if v is None :
112
+ return None
113
+
114
+ err_msg = f"{ v } is not valid distance range, expected ordering min-max with min < max."
115
+
116
+ s_lower , s_upper = v .split ("-" )
117
+ if s_upper == "" :
118
+ s_upper = sys .float_info .max
119
+
120
+ lower = float (s_lower )
121
+ upper = float (s_upper )
122
+
123
+ if lower >= upper :
124
+ raise ValueError (err_msg )
125
+ return (lower , upper )
126
+
127
+
128
+ class Criteria (BaseModel ):
29
129
PassRate : number
30
- CriteriaMethod : Literal ["num_tp" , "metrics_score" ] | None = None
31
- CriteriaLevel : Literal ["perfect" , "hard" , "normal" , "easy" ] | number | None = None
130
+ CriteriaMethod : (
131
+ Literal ["num_tp" , "label" , "metrics_score" , "metrics_score_maph" ] | list [str ] | None
132
+ ) = None
133
+ CriteriaLevel : (
134
+ Literal ["perfect" , "hard" , "normal" , "easy" ] | list [str ] | number | list [number ] | None
135
+ ) = None
136
+ Filter : Filter
137
+
138
+
139
+ class Conditions (BaseModel ):
140
+ Criterion : list [Criteria ]
32
141
33
142
34
143
class Evaluation (BaseModel ):
35
144
UseCaseName : Literal ["traffic_light" ]
36
- UseCaseFormatVersion : Literal ["0.2.0" , "0.3 .0" ]
145
+ UseCaseFormatVersion : Literal ["1.0 .0" ]
37
146
Datasets : list [dict ]
38
147
Conditions : Conditions
39
148
PerceptionEvaluationConfig : dict
@@ -45,6 +154,36 @@ class TrafficLightScenario(Scenario):
45
154
Evaluation : Evaluation
46
155
47
156
157
+ class FailResultHolder :
158
+ def __init__ (self , save_dir : str ) -> None :
159
+ self .save_path : str = Path (save_dir , "fail_info.json" )
160
+ self .buffer = []
161
+
162
+ def add_frame (self , frame_result : PerceptionFrameResult ) -> None :
163
+ if frame_result .pass_fail_result .get_fail_object_num () <= 0 :
164
+ return
165
+ info = {"fp" : [], "fn" : []}
166
+ info ["timestamp" ] = frame_result .frame_ground_truth .unix_time
167
+ for fp_result in frame_result .pass_fail_result .fp_object_results :
168
+ est_label = fp_result .estimated_object .semantic_label .label .value
169
+ gt_label = (
170
+ fp_result .ground_truth_object .semantic_label .label .value
171
+ if fp_result .ground_truth_object is not None
172
+ else None
173
+ )
174
+ info ["fp" ].append ({"est" : est_label , "gt" : gt_label })
175
+ for fn_object in frame_result .pass_fail_result .fn_objects :
176
+ info ["fn" ].append ({"est" : None , "gt" : fn_object .semantic_label .label .value })
177
+
178
+ info_str = f"Fail timestamp: { info } "
179
+ logging .info (info_str )
180
+ self .buffer .append (info )
181
+
182
+ def save (self ) -> None :
183
+ with self .save_path .open ("w" ) as f :
184
+ json .dump (self .buffer , f , ensure_ascii = False , indent = 4 )
185
+
186
+
48
187
@dataclass
49
188
class Perception (EvaluationItem ):
50
189
success : bool = True
@@ -55,11 +194,12 @@ def __post_init__(self) -> None:
55
194
self .criteria : PerceptionCriteria = PerceptionCriteria (
56
195
methods = self .condition .CriteriaMethod ,
57
196
levels = self .condition .CriteriaLevel ,
197
+ distance_range = self .condition .Filter .Distance ,
58
198
)
59
199
60
200
def set_frame (self , frame : PerceptionFrameResult ) -> dict :
61
201
frame_success = "Fail"
62
- result , _ = self .criteria .get_result (frame )
202
+ result , ret_frame = self .criteria .get_result (frame )
63
203
64
204
if result is None :
65
205
self .no_gt_no_obj += 1
@@ -76,28 +216,30 @@ def set_frame(self, frame: PerceptionFrameResult) -> dict:
76
216
return {
77
217
"PassFail" : {
78
218
"Result" : {"Total" : self .success_str (), "Frame" : frame_success },
79
- "Info" : {
80
- "TP" : len (frame .pass_fail_result .tp_object_results ),
81
- "FP" : len (frame .pass_fail_result .fp_object_results ),
82
- "FN" : len (frame .pass_fail_result .fn_objects ),
83
- },
219
+ "Info" : summarize_pass_fail_result (ret_frame .pass_fail_result ),
84
220
},
85
221
}
86
222
87
223
88
224
class TrafficLightResult (ResultBase ):
89
225
def __init__ (self , condition : Conditions ) -> None :
90
226
super ().__init__ ()
91
- self .__perception = Perception (condition = condition )
227
+ self .__perception_criterion : list [Perception ] = []
228
+ for i , criteria in enumerate (condition .Criterion ):
229
+ self .__perception_criterion .append (
230
+ Perception (name = f"criteria{ i } " , condition = criteria ),
231
+ )
92
232
93
233
def update (self ) -> None :
94
- summary_str = f"{ self .__perception .summary } "
95
- if self .__perception .success :
96
- self ._success = True
97
- self ._summary = f"Passed: { summary_str } "
98
- else :
99
- self ._success = False
100
- self ._summary = f"Failed: { summary_str } "
234
+ all_summary : list [str ] = []
235
+ all_success : list [bool ] = []
236
+ for criterion in self .__perception_criterion :
237
+ tmp_success = criterion .success
238
+ prefix_str = "Passed: " if tmp_success else "Failed: "
239
+ all_summary .append (prefix_str + criterion .summary )
240
+ all_success .append (tmp_success )
241
+ self ._summary = ", " .join (all_summary )
242
+ self ._success = all (all_success )
101
243
102
244
def set_frame (
103
245
self ,
@@ -110,7 +252,8 @@ def set_frame(
110
252
"FrameName" : frame .frame_name ,
111
253
"FrameSkip" : skip ,
112
254
}
113
- self ._frame |= self .__perception .set_frame (frame )
255
+ for criterion in self .__perception_criterion :
256
+ self ._frame [criterion .name ] = criterion .set_frame (frame )
114
257
self .update ()
115
258
116
259
def set_final_metrics (self , final_metrics : dict ) -> None :
0 commit comments