Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(multi_object_tracker): add object class filtering in tracking process #6607

Merged
Prev Previous commit
fix: revise the filtering process flow
Signed-off-by: Taekjin LEE <taekjin.lee@tier4.jp>
technolojin committed Mar 15, 2024
commit 6c00b63f137fccf80c09b06297d8f5619351e776
59 changes: 33 additions & 26 deletions perception/multi_object_tracker/src/tracker/model/tracker_base.cpp
Original file line number Diff line number Diff line change
@@ -58,30 +58,36 @@ void Tracker::updateClassification(
const std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> & classification)
{
// classification algorithm:
// 0. Remove the class with probability < 0.005
// 1. Decay all existing probability
// 2. Update the matched classification probability with a gain
// 3. If the label is not found, add it to the classification list
// 4. Normalize
// 0. Normalize the input classification
// 1-1. Update the matched classification probability with a gain (ratio of 0.05)
// 1-2. If the label is not found, add it to the classification list
// 2. Remove the class with probability < remove_threshold (0.001)
// 3. Normalize tracking classification

// Gain and decay
// Parameters
// if the remove_threshold is too high (compare to the gain), the classification will be removed
// immediately
const double gain = 0.05;
const double decay = 1.0 - gain;
constexpr double remove_threshold = 0.001;

// If the probability is less than 0.005, remove the class
classification_.erase(
std::remove_if(
classification_.begin(), classification_.end(),
[](const auto & class_) { return class_.probability < 0.005; }),
classification_.end());
// Normalization function
auto normalizeProbabilities =
[](std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> & classification) {
double sum = 0.0;
for (const auto & class_ : classification) {
sum += class_.probability;
}
for (auto & class_ : classification) {
class_.probability /= sum;
}
};

// Decay all existing probability
for (auto & class_ : classification_) {
class_.probability *= decay;
}
// Normalize the input
auto classification_input = classification;
normalizeProbabilities(classification_input);

// Update the matched classification probability with a gain
for (const auto & new_class : classification) {
for (const auto & new_class : classification_input) {
bool found = false;
for (auto & old_class : classification_) {
if (new_class.label == old_class.label) {
@@ -98,14 +104,15 @@ void Tracker::updateClassification(
}
}

// Normalize
double sum = 0.0;
for (const auto & class_ : classification_) {
sum += class_.probability;
}
for (auto & class_ : classification_) {
class_.probability /= sum;
}
// If the probability is less than the threshold, remove the class
classification_.erase(
std::remove_if(
classification_.begin(), classification_.end(),
[remove_threshold](const auto & class_) { return class_.probability < remove_threshold; }),
classification_.end());

// Normalize tracking classification
normalizeProbabilities(classification_);
}

geometry_msgs::msg::PoseWithCovariance Tracker::getPoseWithCovariance(