@@ -34,6 +34,9 @@ Tracker::Tracker(
34
34
std::mt19937 gen (std::random_device{}());
35
35
std::independent_bits_engine<std::mt19937, 8 , uint8_t > bit_eng (gen);
36
36
std::generate (uuid_.uuid .begin (), uuid_.uuid .end (), bit_eng);
37
+
38
+ // initialize last_filtered_class_
39
+ last_filtered_class_ = object_recognition_utils::getHighestProbClassification (classification_);
37
40
}
38
41
39
42
bool Tracker::updateWithMeasurement (
@@ -54,6 +57,83 @@ bool Tracker::updateWithoutMeasurement()
54
57
return true ;
55
58
}
56
59
60
+ void Tracker::updateClassification (
61
+ const std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> & classification)
62
+ {
63
+ // Update classification
64
+ // 1. Match classification label
65
+ // 2. Update the matched classification probability with a gain
66
+ // 3. If the label is not found, add it to the classification list
67
+ // 4. If the old class probability is not found, decay the probability
68
+ // 5. Normalize the probability
69
+
70
+ const double gain = 0.05 ;
71
+ const double gain_inv = 1.0 - gain;
72
+ const double decay = gain_inv;
73
+
74
+ for (const auto & new_class : classification) {
75
+ bool found = false ;
76
+ for (auto & old_class : classification_) {
77
+ // Update the matched classification probability with a gain
78
+ if (new_class.label == old_class.label ) {
79
+ old_class.probability = old_class.probability * gain_inv + new_class.probability * gain;
80
+ found = true ;
81
+ break ;
82
+ }
83
+ }
84
+ // If the label is not found, add it to the classification list
85
+ if (!found) {
86
+ classification_.push_back (new_class);
87
+ }
88
+ }
89
+ // If the old class probability is not found, decay the probability
90
+ for (auto & old_class : classification_) {
91
+ bool found = false ;
92
+ for (const auto & new_class : classification) {
93
+ if (new_class.label == old_class.label ) {
94
+ found = true ;
95
+ break ;
96
+ }
97
+ }
98
+ if (!found) {
99
+ old_class.probability *= decay;
100
+ }
101
+ }
102
+
103
+ // Normalize
104
+ double sum = 0.0 ;
105
+ for (const auto & class_ : classification_) {
106
+ sum += class_.probability ;
107
+ }
108
+ for (auto & class_ : classification_) {
109
+ class_.probability /= sum;
110
+ }
111
+
112
+ // If the probability is too small, remove the class
113
+ classification_.erase (
114
+ std::remove_if (
115
+ classification_.begin (), classification_.end (),
116
+ [](const auto & class_) { return class_.probability < 0.001 ; }),
117
+ classification_.end ());
118
+
119
+ // Set the last filtered class
120
+ // if the highest probability class is not overcome a certain hysteresis, the last
121
+ // filtered class stays the same
122
+
123
+ for (const auto & class_ : classification_) {
124
+ if (class_.label == last_filtered_class_.label ) {
125
+ last_filtered_class_.probability = class_.probability ;
126
+ break ;
127
+ }
128
+ }
129
+ const double hysteresis = 0.1 ;
130
+ autoware_auto_perception_msgs::msg::ObjectClassification const new_classification =
131
+ object_recognition_utils::getHighestProbClassification (classification_);
132
+ if (new_classification.probability > last_filtered_class_.probability + hysteresis) {
133
+ last_filtered_class_ = new_classification;
134
+ }
135
+ }
136
+
57
137
geometry_msgs::msg::PoseWithCovariance Tracker::getPoseWithCovariance (
58
138
const rclcpp::Time & time) const
59
139
{
0 commit comments