@@ -58,30 +58,36 @@ void Tracker::updateClassification(
58
58
const std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> & classification)
59
59
{
60
60
// classification algorithm:
61
- // 0. Remove the class with probability < 0.005
62
- // 1. Decay all existing probability
63
- // 2. Update the matched classification probability with a gain
64
- // 3. If the label is not found, add it to the classification list
65
- // 4 . Normalize
61
+ // 0. Normalize the input classification
62
+ // 1-1. Update the matched classification probability with a gain (ratio of 0.05)
63
+ // 1- 2. If the label is not found, add it to the classification list
64
+ // 2. Remove the class with probability < remove_threshold (0.001)
65
+ // 3 . Normalize tracking classification
66
66
67
- // Gain and decay
67
+ // Parameters
68
+ // if the remove_threshold is too high (compare to the gain), the classification will be removed
69
+ // immediately
68
70
const double gain = 0.05 ;
69
- const double decay = 1.0 - gain ;
71
+ constexpr double remove_threshold = 0.001 ;
70
72
71
- // If the probability is less than 0.005, remove the class
72
- classification_.erase (
73
- std::remove_if (
74
- classification_.begin (), classification_.end (),
75
- [](const auto & class_) { return class_.probability < 0.005 ; }),
76
- classification_.end ());
73
+ // Normalization function
74
+ auto normalizeProbabilities =
75
+ [](std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> & classification) {
76
+ double sum = 0.0 ;
77
+ for (const auto & class_ : classification) {
78
+ sum += class_.probability ;
79
+ }
80
+ for (auto & class_ : classification) {
81
+ class_.probability /= sum;
82
+ }
83
+ };
77
84
78
- // Decay all existing probability
79
- for (auto & class_ : classification_) {
80
- class_.probability *= decay;
81
- }
85
+ // Normalize the input
86
+ auto classification_input = classification;
87
+ normalizeProbabilities (classification_input);
82
88
83
89
// Update the matched classification probability with a gain
84
- for (const auto & new_class : classification ) {
90
+ for (const auto & new_class : classification_input ) {
85
91
bool found = false ;
86
92
for (auto & old_class : classification_) {
87
93
if (new_class.label == old_class.label ) {
@@ -98,14 +104,15 @@ void Tracker::updateClassification(
98
104
}
99
105
}
100
106
101
- // Normalize
102
- double sum = 0.0 ;
103
- for (const auto & class_ : classification_) {
104
- sum += class_.probability ;
105
- }
106
- for (auto & class_ : classification_) {
107
- class_.probability /= sum;
108
- }
107
+ // If the probability is less than the threshold, remove the class
108
+ classification_.erase (
109
+ std::remove_if (
110
+ classification_.begin (), classification_.end (),
111
+ [remove_threshold](const auto & class_) { return class_.probability < remove_threshold; }),
112
+ classification_.end ());
113
+
114
+ // Normalize tracking classification
115
+ normalizeProbabilities (classification_);
109
116
}
110
117
111
118
geometry_msgs::msg::PoseWithCovariance Tracker::getPoseWithCovariance (
0 commit comments