Skip to content

Commit d7d6071

Browse files
committedMar 24, 2025·
feat(autoware_trajectory): improve get_underlying_base performance
Signed-off-by: Mamoru Sobue <mamoru.sobue@tier4.jp>
1 parent f713436 commit d7d6071

11 files changed

+104
-62
lines changed
 

‎common/autoware_trajectory/include/autoware/trajectory/detail/helpers.hpp

-24
Original file line numberDiff line numberDiff line change
@@ -16,36 +16,12 @@
1616
#define AUTOWARE__TRAJECTORY__DETAIL__HELPERS_HPP_
1717

1818
#include <cstddef>
19-
#include <set>
2019
#include <vector>
2120

2221
namespace autoware::trajectory::detail
2322
{
2423
inline namespace helpers
2524
{
26-
/**
27-
* @brief Merge multiple vectors into one, keeping only unique elements.
28-
* @tparam Vectors Variadic template parameter for vector types.
29-
* @param vectors Vectors to be merged.
30-
* @return std::vector<double> Merged vector with unique elements.
31-
*/
32-
template <typename... Vectors>
33-
std::vector<double> merge_vectors(const Vectors &... vectors)
34-
{
35-
std::set<double> unique_elements;
36-
37-
// Helper function to insert elements into the set
38-
auto insert_elements = [&unique_elements](const auto & vec) {
39-
unique_elements.insert(vec.begin(), vec.end());
40-
};
41-
42-
// Expand the parameter pack and insert elements from each vector
43-
(insert_elements(vectors), ...);
44-
45-
// Convert the set to std::vector<double>
46-
return {unique_elements.begin(), unique_elements.end()};
47-
}
48-
4925
/**
5026
* @brief Ensures the output vector has at least a specified number of points by linearly
5127
* interpolating values between each input intervals

‎common/autoware_trajectory/include/autoware/trajectory/detail/interpolated_array.hpp

+16-7
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
#include "autoware/trajectory/detail/logging.hpp"
1919
#include "autoware/trajectory/interpolator/interpolator.hpp"
2020

21-
#include <Eigen/Core>
2221
#include <rclcpp/logging.hpp>
2322

2423
#include <algorithm>
@@ -42,6 +41,7 @@ class InterpolatedArray
4241
std::vector<double> bases_;
4342
std::vector<T> values_;
4443
std::shared_ptr<interpolator::InterpolatorInterface<T>> interpolator_;
44+
std::function<void(const double s)> base_addition_callback_slot_{nullptr};
4545

4646
public:
4747
/**
@@ -97,13 +97,22 @@ class InterpolatedArray
9797
bases_ = other.bases_;
9898
values_ = other.values_;
9999
interpolator_ = other.interpolator_->clone();
100+
base_addition_callback_slot_ = other.base_addition_callback_slot_;
100101
}
101102
return *this;
102103
}
103104

104105
// Destructor
105106
~InterpolatedArray() = default;
106107

108+
/**
109+
* @brief add the callback function to be executed when a new base is added to this class
110+
*/
111+
void connect_base_addition_callback(std::function<void(const double s)> && signal)
112+
{
113+
base_addition_callback_slot_ = std::move(signal);
114+
}
115+
107116
/**
108117
* @brief Get the start value of the base.
109118
* @return The start value.
@@ -143,6 +152,12 @@ class InterpolatedArray
143152
return index;
144153
} // Insert into bases
145154
bases.insert(it, val);
155+
156+
// execute the callback to notify that a new base has been added
157+
if (parent_.base_addition_callback_slot_) {
158+
std::invoke(parent_.base_addition_callback_slot_, value);
159+
}
160+
146161
// Insert into values at the corresponding position
147162
values.insert(values.begin() + index, value);
148163
return index;
@@ -211,12 +226,6 @@ class InterpolatedArray
211226
* @return The interpolated value.
212227
*/
213228
T compute(const double x) const { return interpolator_->compute(x); }
214-
215-
/**
216-
* @brief Get the underlying data of the array.
217-
* @return A pair containing the axis and values.
218-
*/
219-
std::pair<std::vector<double>, std::vector<T>> get_data() const { return {bases_, values_}; }
220229
};
221230

222231
} // namespace autoware::trajectory::detail

‎common/autoware_trajectory/include/autoware/trajectory/path_point.hpp

+9
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,15 @@ class Trajectory<autoware_planning_msgs::msg::PathPoint>
4242
std::shared_ptr<detail::InterpolatedArray<double>> heading_rate_rps_{
4343
nullptr}; //!< Heading rate in rad/s};
4444

45+
/**
46+
* @brief add the event function to
47+
* longitudinal_velocity_mps/lateral_velocity_mps/heading_rate_mps interpolator using observer
48+
* pattern
49+
* @note when a new base is added to longitudinal_velocity_mps for example, the addition is also
50+
* notified and update_base() is triggered.
51+
*/
52+
virtual void add_base_addition_callback();
53+
4554
public:
4655
Trajectory();
4756
~Trajectory() override = default;

‎common/autoware_trajectory/include/autoware/trajectory/path_point_with_lane_id.hpp

+7
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@ class Trajectory<autoware_internal_planning_msgs::msg::PathPointWithLaneId>
3636
protected:
3737
std::shared_ptr<detail::InterpolatedArray<LaneIdType>> lane_ids_{nullptr}; //!< Lane ID
3838

39+
/**
40+
* @brief add the event function to lane_ids additionally
41+
* @note when a new base is added to lane_ids for example, the addition is also
42+
* notified and update_base() is triggered.
43+
*/
44+
void add_base_addition_callback() override;
45+
3946
public:
4047
Trajectory();
4148
~Trajectory() override = default;

‎common/autoware_trajectory/include/autoware/trajectory/point.hpp

+5
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ class Trajectory<geometry_msgs::msg::Point>
4949

5050
double start_{0.0}, end_{0.0}; //!< Start and end of the arc length of the trajectory
5151

52+
/**
53+
* @brief add the input s if it is not contained in bases_
54+
*/
55+
void update_bases(const double s);
56+
5257
/**
5358
* @brief Validate the arc length is within the trajectory
5459
* @param s Arc length

‎common/autoware_trajectory/include/autoware/trajectory/trajectory_point.hpp

+9
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,15 @@ class Trajectory<autoware_planning_msgs::msg::TrajectoryPoint>
4747
std::shared_ptr<detail::InterpolatedArray<double>> rear_wheel_angle_rad_{
4848
nullptr}; //!< Rear wheel angle in rad} Warning, this is not used
4949

50+
/**
51+
* @brief add the event function to
52+
* longitudinal_velocity_mps/lateral_velocity_mps/heading_rate_mps/acceleration_mps2/front_wheel_angle_rad/rear_wheel_angle_rad
53+
* interpolator
54+
* @note when a new base is added to longitudinal_velocity_mps for example, the addition is also
55+
* notified and update_base() is triggered.
56+
*/
57+
virtual void add_base_addition_callback();
58+
5059
public:
5160
Trajectory();
5261
~Trajectory() override = default;

‎common/autoware_trajectory/src/path_point.cpp

+14-10
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,20 @@ namespace autoware::trajectory
2929

3030
using PointType = autoware_planning_msgs::msg::PathPoint;
3131

32+
void Trajectory<PointType>::add_base_addition_callback()
33+
{
34+
longitudinal_velocity_mps_->connect_base_addition_callback(
35+
[&](const double s) { return this->update_bases(s); });
36+
lateral_velocity_mps_->connect_base_addition_callback(
37+
[&](const double s) { return this->update_bases(s); });
38+
heading_rate_rps_->connect_base_addition_callback(
39+
[&](const double s) { return this->update_bases(s); });
40+
}
41+
3242
Trajectory<PointType>::Trajectory()
3343
{
3444
Builder::defaults(this);
45+
add_base_addition_callback();
3546
}
3647

3748
Trajectory<PointType>::Trajectory(const Trajectory & rhs)
@@ -42,6 +53,7 @@ Trajectory<PointType>::Trajectory(const Trajectory & rhs)
4253
std::make_shared<detail::InterpolatedArray<double>>(*rhs.lateral_velocity_mps_)),
4354
heading_rate_rps_(std::make_shared<detail::InterpolatedArray<double>>(*rhs.heading_rate_rps_))
4455
{
56+
add_base_addition_callback();
4557
}
4658

4759
Trajectory<PointType> & Trajectory<PointType>::operator=(const Trajectory & rhs)
@@ -52,6 +64,7 @@ Trajectory<PointType> & Trajectory<PointType>::operator=(const Trajectory & rhs)
5264
*lateral_velocity_mps_ = *rhs.lateral_velocity_mps_;
5365
*heading_rate_rps_ = *rhs.heading_rate_rps_;
5466
}
67+
add_base_addition_callback();
5568
return *this;
5669
}
5770

@@ -98,16 +111,7 @@ interpolator::InterpolationResult Trajectory<PointType>::build(
98111

99112
std::vector<double> Trajectory<PointType>::get_underlying_bases() const
100113
{
101-
auto get_bases = [](const auto & interpolated_array) {
102-
auto [bases, values] = interpolated_array.get_data();
103-
return bases;
104-
};
105-
106-
auto bases = detail::merge_vectors(
107-
bases_, get_bases(this->longitudinal_velocity_mps()), get_bases(this->lateral_velocity_mps()),
108-
get_bases(this->heading_rate_rps()));
109-
110-
bases = detail::crop_bases(bases, start_, end_);
114+
auto bases = detail::crop_bases(bases_, start_, end_);
111115
std::transform(
112116
bases.begin(), bases.end(), bases.begin(), [this](const double & s) { return s - start_; });
113117
return bases;

‎common/autoware_trajectory/src/path_point_with_lane_id.cpp

+9-10
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,16 @@ namespace autoware::trajectory
2626

2727
using PointType = autoware_internal_planning_msgs::msg::PathPointWithLaneId;
2828

29+
void Trajectory<PointType>::add_base_addition_callback()
30+
{
31+
BaseClass::add_base_addition_callback();
32+
lane_ids_->connect_base_addition_callback([&](const double s) { return this->update_bases(s); });
33+
}
34+
2935
Trajectory<PointType>::Trajectory()
3036
{
3137
Builder::defaults(this);
38+
add_base_addition_callback();
3239
}
3340

3441
Trajectory<PointType> & Trajectory<PointType>::operator=(const Trajectory & rhs)
@@ -37,6 +44,7 @@ Trajectory<PointType> & Trajectory<PointType>::operator=(const Trajectory & rhs)
3744
BaseClass::operator=(rhs);
3845
lane_ids_ = std::make_shared<detail::InterpolatedArray<LaneIdType>>(this->lane_ids());
3946
}
47+
add_base_addition_callback();
4048
return *this;
4149
}
4250

@@ -66,16 +74,7 @@ interpolator::InterpolationResult Trajectory<PointType>::build(
6674

6775
std::vector<double> Trajectory<PointType>::get_underlying_bases() const
6876
{
69-
auto get_bases = [](const auto & interpolated_array) {
70-
auto [bases, values] = interpolated_array.get_data();
71-
return bases;
72-
};
73-
74-
auto bases = detail::merge_vectors(
75-
bases_, get_bases(this->longitudinal_velocity_mps()), get_bases(this->lateral_velocity_mps()),
76-
get_bases(this->heading_rate_rps()), get_bases(this->lane_ids()));
77-
78-
bases = detail::crop_bases(bases, start_, end_);
77+
auto bases = detail::crop_bases(bases_, start_, end_);
7978

8079
std::transform(
8180
bases.begin(), bases.end(), bases.begin(), [this](const double & s) { return s - start_; });

‎common/autoware_trajectory/src/point.cpp

+14
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,20 @@ std::vector<double> Trajectory<PointType>::get_underlying_bases() const
124124
return bases;
125125
}
126126

127+
void Trajectory<PointType>::update_bases(const double s)
128+
{
129+
const auto it = std::lower_bound(bases_.begin(), bases_.end(), s);
130+
if (it == bases_.end()) {
131+
// NOTE(soblin): the extension of base(or extrapolation) will be supported by other API.
132+
return;
133+
}
134+
if (*it == s) {
135+
// already inserted
136+
return;
137+
}
138+
bases_.insert(it, s);
139+
}
140+
127141
double Trajectory<PointType>::length() const
128142
{
129143
return end_ - start_;

‎common/autoware_trajectory/src/pose.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ PointType Trajectory<PointType>::compute(const double s) const
8888
PointType result;
8989
result.position = BaseClass::compute(s);
9090
const auto s_clamp = clamp(s);
91+
// NOTE(soblin): Ideally azimuth() should be used? But okay if the interpolation is not rough
9192
result.orientation = orientation_interpolator_->compute(s_clamp);
9293
return result;
9394
}

‎common/autoware_trajectory/src/trajectory_point.cpp

+20-11
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,26 @@ namespace autoware::trajectory
3131

3232
using PointType = autoware_planning_msgs::msg::TrajectoryPoint;
3333

34+
void Trajectory<PointType>::add_base_addition_callback()
35+
{
36+
longitudinal_velocity_mps_->connect_base_addition_callback(
37+
[&](const double s) { return this->update_bases(s); });
38+
lateral_velocity_mps_->connect_base_addition_callback(
39+
[&](const double s) { return this->update_bases(s); });
40+
heading_rate_rps_->connect_base_addition_callback(
41+
[&](const double s) { return this->update_bases(s); });
42+
acceleration_mps2_->connect_base_addition_callback(
43+
[&](const double s) { return this->update_bases(s); });
44+
front_wheel_angle_rad_->connect_base_addition_callback(
45+
[&](const double s) { return this->update_bases(s); });
46+
rear_wheel_angle_rad_->connect_base_addition_callback(
47+
[&](const double s) { return this->update_bases(s); });
48+
}
49+
3450
Trajectory<PointType>::Trajectory()
3551
{
3652
Builder::defaults(this);
53+
add_base_addition_callback();
3754
}
3855

3956
Trajectory<PointType>::Trajectory(const Trajectory & rhs)
@@ -49,6 +66,7 @@ Trajectory<PointType>::Trajectory(const Trajectory & rhs)
4966
rear_wheel_angle_rad_(
5067
std::make_shared<detail::InterpolatedArray<double>>(*rhs.rear_wheel_angle_rad_))
5168
{
69+
add_base_addition_callback();
5270
}
5371

5472
Trajectory<PointType> & Trajectory<PointType>::operator=(const Trajectory & rhs)
@@ -62,6 +80,7 @@ Trajectory<PointType> & Trajectory<PointType>::operator=(const Trajectory & rhs)
6280
*front_wheel_angle_rad_ = *rhs.front_wheel_angle_rad_;
6381
*rear_wheel_angle_rad_ = *rhs.rear_wheel_angle_rad_;
6482
}
83+
add_base_addition_callback();
6584
return *this;
6685
}
6786

@@ -141,17 +160,7 @@ interpolator::InterpolationResult Trajectory<PointType>::build(
141160

142161
std::vector<double> Trajectory<PointType>::get_underlying_bases() const
143162
{
144-
auto get_bases = [](const auto & interpolated_array) {
145-
auto [bases, values] = interpolated_array.get_data();
146-
return bases;
147-
};
148-
149-
auto bases = detail::merge_vectors(
150-
bases_, get_bases(this->longitudinal_velocity_mps()), get_bases(this->lateral_velocity_mps()),
151-
get_bases(this->heading_rate_rps()), get_bases(this->acceleration_mps2()),
152-
get_bases(this->front_wheel_angle_rad()), get_bases(this->rear_wheel_angle_rad()));
153-
154-
bases = detail::crop_bases(bases, start_, end_);
163+
auto bases = detail::crop_bases(bases_, start_, end_);
155164
std::transform(
156165
bases.begin(), bases.end(), bases.begin(), [this](const double & s) { return s - start_; });
157166
return bases;

0 commit comments

Comments
 (0)