Skip to content

Commit 3fd3469

Browse files
committed
feat: object class filter
1 parent fcebec6 commit 3fd3469

File tree

3 files changed

+88
-1
lines changed

3 files changed

+88
-1
lines changed

perception/multi_object_tracker/include/multi_object_tracker/tracker/model/tracker_base.hpp

+6
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ class Tracker
4242
{
4343
classification_ = classification;
4444
}
45+
void updateClassification(
46+
const std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> & classification);
4547

4648
private:
4749
unique_identifier_msgs::msg::UUID uuid_;
@@ -51,6 +53,9 @@ class Tracker
5153
int total_measurement_count_;
5254
rclcpp::Time last_update_with_measurement_time_;
5355

56+
public:
57+
autoware_auto_perception_msgs::msg::ObjectClassification last_filtered_class_;
58+
5459
public:
5560
Tracker(
5661
const rclcpp::Time & time,
@@ -68,6 +73,7 @@ class Tracker
6873
{
6974
return object_recognition_utils::getHighestProbLabel(classification_);
7075
}
76+
std::uint8_t getFilteredLabel() const { return last_filtered_class_.label; }
7177
int getNoMeasurementCount() const { return no_measurement_count_; }
7278
int getTotalNoMeasurementCount() const { return total_no_measurement_count_; }
7379
int getTotalMeasurementCount() const { return total_measurement_count_; }

perception/multi_object_tracker/src/tracker/model/pedestrian_and_bicycle_tracker.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ bool PedestrianAndBicycleTracker::measure(
4343
pedestrian_tracker_.measure(object, time, self_transform);
4444
bicycle_tracker_.measure(object, time, self_transform);
4545
if (object_recognition_utils::getHighestProbLabel(object.classification) != Label::UNKNOWN)
46-
setClassification(object.classification);
46+
// setClassification(object.classification);
47+
updateClassification(object.classification);
4748
return true;
4849
}
4950

perception/multi_object_tracker/src/tracker/model/tracker_base.cpp

+80
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ Tracker::Tracker(
3434
std::mt19937 gen(std::random_device{}());
3535
std::independent_bits_engine<std::mt19937, 8, uint8_t> bit_eng(gen);
3636
std::generate(uuid_.uuid.begin(), uuid_.uuid.end(), bit_eng);
37+
38+
// initialize last_filtered_class_
39+
last_filtered_class_ = object_recognition_utils::getHighestProbClassification(classification_);
3740
}
3841

3942
bool Tracker::updateWithMeasurement(
@@ -54,6 +57,83 @@ bool Tracker::updateWithoutMeasurement()
5457
return true;
5558
}
5659

60+
void Tracker::updateClassification(
61+
const std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> & classification)
62+
{
63+
// Update classification
64+
// 1. Match classification label
65+
// 2. Update the matched classification probability with a gain
66+
// 3. If the label is not found, add it to the classification list
67+
// 4. If the old class probability is not found, decay the probability
68+
// 5. Normalize the probability
69+
70+
const double gain = 0.05;
71+
const double gain_inv = 1.0 - gain;
72+
const double decay = gain_inv;
73+
74+
for (const auto & new_class : classification) {
75+
bool found = false;
76+
for (auto & old_class : classification_) {
77+
// Update the matched classification probability with a gain
78+
if (new_class.label == old_class.label) {
79+
old_class.probability = old_class.probability * gain_inv + new_class.probability * gain;
80+
found = true;
81+
break;
82+
}
83+
}
84+
// If the label is not found, add it to the classification list
85+
if (!found) {
86+
classification_.push_back(new_class);
87+
}
88+
}
89+
// If the old class probability is not found, decay the probability
90+
for (auto & old_class : classification_) {
91+
bool found = false;
92+
for (const auto & new_class : classification) {
93+
if (new_class.label == old_class.label) {
94+
found = true;
95+
break;
96+
}
97+
}
98+
if (!found) {
99+
old_class.probability *= decay;
100+
}
101+
}
102+
103+
// Normalize
104+
double sum = 0.0;
105+
for (const auto & class_ : classification_) {
106+
sum += class_.probability;
107+
}
108+
for (auto & class_ : classification_) {
109+
class_.probability /= sum;
110+
}
111+
112+
// If the probability is too small, remove the class
113+
classification_.erase(
114+
std::remove_if(
115+
classification_.begin(), classification_.end(),
116+
[](const auto & class_) { return class_.probability < 0.001; }),
117+
classification_.end());
118+
119+
// Set the last filtered class
120+
// if the highest probability class is not overcome a certain hysteresis, the last
121+
// filtered class stays the same
122+
123+
for (const auto & class_ : classification_) {
124+
if (class_.label == last_filtered_class_.label) {
125+
last_filtered_class_.probability = class_.probability;
126+
break;
127+
}
128+
}
129+
const double hysteresis = 0.1;
130+
autoware_auto_perception_msgs::msg::ObjectClassification const new_classification =
131+
object_recognition_utils::getHighestProbClassification(classification_);
132+
if (new_classification.probability > last_filtered_class_.probability + hysteresis) {
133+
last_filtered_class_ = new_classification;
134+
}
135+
}
136+
57137
geometry_msgs::msg::PoseWithCovariance Tracker::getPoseWithCovariance(
58138
const rclcpp::Time & time) const
59139
{

0 commit comments

Comments
 (0)