Skip to content

Commit 6c00b63

Browse files
committed
fix: revise the filtering process flow
Signed-off-by: Taekjin LEE <taekjin.lee@tier4.jp>
1 parent 9ab9c30 commit 6c00b63

File tree

1 file changed

+33
-26
lines changed

1 file changed

+33
-26
lines changed

perception/multi_object_tracker/src/tracker/model/tracker_base.cpp

+33-26
Original file line numberDiff line numberDiff line change
@@ -58,30 +58,36 @@ void Tracker::updateClassification(
5858
const std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> & classification)
5959
{
6060
// classification algorithm:
61-
// 0. Remove the class with probability < 0.005
62-
// 1. Decay all existing probability
63-
// 2. Update the matched classification probability with a gain
64-
// 3. If the label is not found, add it to the classification list
65-
// 4. Normalize
61+
// 0. Normalize the input classification
62+
// 1-1. Update the matched classification probability with a gain (ratio of 0.05)
63+
// 1-2. If the label is not found, add it to the classification list
64+
// 2. Remove the class with probability < remove_threshold (0.001)
65+
// 3. Normalize tracking classification
6666

67-
// Gain and decay
67+
// Parameters
68+
// if the remove_threshold is too high (compare to the gain), the classification will be removed
69+
// immediately
6870
const double gain = 0.05;
69-
const double decay = 1.0 - gain;
71+
constexpr double remove_threshold = 0.001;
7072

71-
// If the probability is less than 0.005, remove the class
72-
classification_.erase(
73-
std::remove_if(
74-
classification_.begin(), classification_.end(),
75-
[](const auto & class_) { return class_.probability < 0.005; }),
76-
classification_.end());
73+
// Normalization function
74+
auto normalizeProbabilities =
75+
[](std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> & classification) {
76+
double sum = 0.0;
77+
for (const auto & class_ : classification) {
78+
sum += class_.probability;
79+
}
80+
for (auto & class_ : classification) {
81+
class_.probability /= sum;
82+
}
83+
};
7784

78-
// Decay all existing probability
79-
for (auto & class_ : classification_) {
80-
class_.probability *= decay;
81-
}
85+
// Normalize the input
86+
auto classification_input = classification;
87+
normalizeProbabilities(classification_input);
8288

8389
// Update the matched classification probability with a gain
84-
for (const auto & new_class : classification) {
90+
for (const auto & new_class : classification_input) {
8591
bool found = false;
8692
for (auto & old_class : classification_) {
8793
if (new_class.label == old_class.label) {
@@ -98,14 +104,15 @@ void Tracker::updateClassification(
98104
}
99105
}
100106

101-
// Normalize
102-
double sum = 0.0;
103-
for (const auto & class_ : classification_) {
104-
sum += class_.probability;
105-
}
106-
for (auto & class_ : classification_) {
107-
class_.probability /= sum;
108-
}
107+
// If the probability is less than the threshold, remove the class
108+
classification_.erase(
109+
std::remove_if(
110+
classification_.begin(), classification_.end(),
111+
[remove_threshold](const auto & class_) { return class_.probability < remove_threshold; }),
112+
classification_.end());
113+
114+
// Normalize tracking classification
115+
normalizeProbabilities(classification_);
109116
}
110117

111118
geometry_msgs::msg::PoseWithCovariance Tracker::getPoseWithCovariance(

0 commit comments

Comments
 (0)