Skip to content

Commit 341d407

Browse files
authored
fix(multi_object_tracker): fix a bug on the existence probability per input channel (#7269)
* fix: set min and max probability Signed-off-by: Taekjin LEE <taekjin.lee@tier4.jp> * fix: wrong decay rate Signed-off-by: Taekjin LEE <taekjin.lee@tier4.jp> * fix: set minimum probability of 0.001 Signed-off-by: Taekjin LEE <taekjin.lee@tier4.jp> * fix: update decay rate to use float instead of double Signed-off-by: Taekjin LEE <taekjin.lee@tier4.jp> --------- Signed-off-by: Taekjin LEE <taekjin.lee@tier4.jp>
1 parent f40440f commit 341d407

File tree

1 file changed

+38
-25
lines changed

1 file changed

+38
-25
lines changed

perception/multi_object_tracker/src/tracker/model/tracker_base.cpp

+38-25
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,20 @@
2323

2424
namespace
2525
{
26-
float updateProbability(float prior, float true_positive, float false_positive)
26+
float updateProbability(
27+
const float & prior, const float & true_positive, const float & false_positive)
2728
{
28-
return (prior * true_positive) / (prior * true_positive + (1 - prior) * false_positive);
29+
constexpr float max_updated_probability = 0.999;
30+
constexpr float min_updated_probability = 0.100;
31+
const float probability =
32+
(prior * true_positive) / (prior * true_positive + (1 - prior) * false_positive);
33+
return std::max(std::min(probability, max_updated_probability), min_updated_probability);
34+
}
35+
float decayProbability(const float & prior, const float & delta_time)
36+
{
37+
constexpr float minimum_probability = 0.001;
38+
const float decay_rate = log(0.5f) / 0.3f; // half-life (50% decay) of 0.3s
39+
return std::max(prior * std::exp(decay_rate * delta_time), minimum_probability);
2940
}
3041
} // namespace
3142

@@ -45,7 +56,8 @@ Tracker::Tracker(
4556
std::generate(uuid_.uuid.begin(), uuid_.uuid.end(), bit_eng);
4657

4758
// Initialize existence probabilities
48-
existence_probabilities_.resize(channel_size, 0.0);
59+
existence_probabilities_.resize(channel_size, 0.001);
60+
total_existence_probability_ = 0.001;
4961
}
5062

5163
void Tracker::initializeExistenceProbabilities(
@@ -58,7 +70,10 @@ void Tracker::initializeExistenceProbabilities(
5870
existence_probabilities_[channel_index] = initial_existence_probability;
5971

6072
// total existence probability
61-
total_existence_probability_ = existence_probability;
73+
constexpr float max_probability = 0.999;
74+
constexpr float min_probability = 0.100;
75+
total_existence_probability_ =
76+
std::max(std::min(existence_probability, max_probability), min_probability);
6277
}
6378

6479
bool Tracker::updateWithMeasurement(
@@ -72,27 +87,26 @@ bool Tracker::updateWithMeasurement(
7287
++total_measurement_count_;
7388

7489
// existence probability on each channel
75-
const double delta_time = (measurement_time - last_update_with_measurement_time_).seconds();
76-
const double decay_rate = -log(0.5) / 0.3; // 50% decay in 0.3s
90+
const float delta_time =
91+
std::abs((measurement_time - last_update_with_measurement_time_).seconds());
7792
constexpr float probability_true_detection = 0.9;
7893
constexpr float probability_false_detection = 0.2;
7994

80-
// update measured channel probability
95+
// update measured channel probability without decay
8196
existence_probabilities_[channel_index] = updateProbability(
8297
existence_probabilities_[channel_index], probability_true_detection,
8398
probability_false_detection);
99+
84100
// decay other channel probabilities
85101
for (size_t i = 0; i < existence_probabilities_.size(); ++i) {
86-
if (i == channel_index) {
87-
continue;
102+
if (i != channel_index) {
103+
existence_probabilities_[i] = decayProbability(existence_probabilities_[i], delta_time);
88104
}
89-
existence_probabilities_[i] *= std::exp(decay_rate * delta_time);
90105
}
91106

92107
// update total existence probability
93-
const float & existence_probability_from_object = object.existence_probability;
94108
total_existence_probability_ = updateProbability(
95-
total_existence_probability_, existence_probability_from_object, probability_false_detection);
109+
total_existence_probability_, object.existence_probability, probability_false_detection);
96110
}
97111

98112
last_update_with_measurement_time_ = measurement_time;
@@ -110,12 +124,11 @@ bool Tracker::updateWithoutMeasurement(const rclcpp::Time & now)
110124
++total_no_measurement_count_;
111125
{
112126
// decay existence probability
113-
double const delta_time = (now - last_update_with_measurement_time_).seconds();
114-
const double decay_rate = -log(0.5) / 0.3; // 50% decay in 0.3s
115-
for (size_t i = 0; i < existence_probabilities_.size(); ++i) {
116-
existence_probabilities_[i] *= std::exp(-decay_rate * delta_time);
127+
float const delta_time = (now - last_update_with_measurement_time_).seconds();
128+
for (float & existence_probability : existence_probabilities_) {
129+
existence_probability = decayProbability(existence_probability, delta_time);
117130
}
118-
total_existence_probability_ *= std::exp(-decay_rate * delta_time);
131+
total_existence_probability_ = decayProbability(total_existence_probability_, delta_time);
119132
}
120133

121134
return true;
@@ -134,18 +147,18 @@ void Tracker::updateClassification(
134147
// Parameters
135148
// if the remove_threshold is too high (compare to the gain), the classification will be removed
136149
// immediately
137-
const double gain = 0.05;
138-
constexpr double remove_threshold = 0.001;
150+
const float gain = 0.05;
151+
constexpr float remove_threshold = 0.001;
139152

140153
// Normalization function
141154
auto normalizeProbabilities =
142155
[](std::vector<autoware_perception_msgs::msg::ObjectClassification> & classification) {
143-
double sum = 0.0;
144-
for (const auto & class_ : classification) {
145-
sum += class_.probability;
156+
float sum = 0.0;
157+
for (const auto & a_class : classification) {
158+
sum += a_class.probability;
146159
}
147-
for (auto & class_ : classification) {
148-
class_.probability /= sum;
160+
for (auto & a_class : classification) {
161+
a_class.probability /= sum;
149162
}
150163
};
151164

@@ -175,7 +188,7 @@ void Tracker::updateClassification(
175188
classification_.erase(
176189
std::remove_if(
177190
classification_.begin(), classification_.end(),
178-
[remove_threshold](const auto & class_) { return class_.probability < remove_threshold; }),
191+
[remove_threshold](const auto & a_class) { return a_class.probability < remove_threshold; }),
179192
classification_.end());
180193

181194
// Normalize tracking classification

0 commit comments

Comments
 (0)