Skip to content

Commit 3256acf

Browse files
feat(map_based_prediction): consider crosswalks signals (#6189)
* consider the crosswalks signals * update with the reviewers comments Signed-off-by: Yuki Takagi <yuki.takagi@tier4.jp>
1 parent 2b1a33f commit 3256acf

File tree

5 files changed

+74
-0
lines changed

5 files changed

+74
-0
lines changed

perception/map_based_prediction/config/map_based_prediction.param.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
use_vehicle_acceleration: false # whether to consider current vehicle acceleration when predicting paths or not
2020
speed_limit_multiplier: 1.5 # When using vehicle acceleration. Set vehicle's maximum predicted speed as the legal speed limit in that lanelet times this value
2121
acceleration_exponential_half_life: 2.5 # [s] When using vehicle acceleration. The decaying acceleration model considers that the current vehicle acceleration will be halved after this many seconds
22+
use_crosswalk_signal: true
2223
# parameter for shoulder lane prediction
2324
prediction_time_horizon_rate_for_validate_shoulder_lane_length: 0.8
2425

perception/map_based_prediction/include/map_based_prediction/map_based_prediction_node.hpp

+13
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <autoware_auto_mapping_msgs/msg/had_map_bin.hpp>
3030
#include <autoware_auto_perception_msgs/msg/predicted_objects.hpp>
3131
#include <autoware_auto_perception_msgs/msg/tracked_objects.hpp>
32+
#include <autoware_perception_msgs/msg/traffic_signal_array.hpp>
3233
#include <geometry_msgs/msg/pose.hpp>
3334
#include <geometry_msgs/msg/pose_stamped.hpp>
3435
#include <geometry_msgs/msg/twist.hpp>
@@ -42,6 +43,7 @@
4243
#include <algorithm>
4344
#include <deque>
4445
#include <memory>
46+
#include <optional>
4547
#include <string>
4648
#include <unordered_map>
4749
#include <utility>
@@ -107,6 +109,9 @@ using autoware_auto_perception_msgs::msg::TrackedObject;
107109
using autoware_auto_perception_msgs::msg::TrackedObjectKinematics;
108110
using autoware_auto_perception_msgs::msg::TrackedObjects;
109111
using autoware_auto_planning_msgs::msg::TrajectoryPoint;
112+
using autoware_perception_msgs::msg::TrafficSignal;
113+
using autoware_perception_msgs::msg::TrafficSignalArray;
114+
using autoware_perception_msgs::msg::TrafficSignalElement;
110115
using tier4_autoware_utils::StopWatch;
111116
using tier4_debug_msgs::msg::StringStamped;
112117
using TrajectoryPoints = std::vector<TrajectoryPoint>;
@@ -122,6 +127,7 @@ class MapBasedPredictionNode : public rclcpp::Node
122127
rclcpp::Publisher<StringStamped>::SharedPtr pub_calculation_time_;
123128
rclcpp::Subscription<TrackedObjects>::SharedPtr sub_objects_;
124129
rclcpp::Subscription<HADMapBin>::SharedPtr sub_map_;
130+
rclcpp::Subscription<TrafficSignalArray>::SharedPtr sub_traffic_signals_;
125131

126132
// Object History
127133
std::unordered_map<std::string, std::deque<ObjectData>> objects_history_;
@@ -131,6 +137,8 @@ class MapBasedPredictionNode : public rclcpp::Node
131137
std::shared_ptr<lanelet::routing::RoutingGraph> routing_graph_ptr_;
132138
std::shared_ptr<lanelet::traffic_rules::TrafficRules> traffic_rules_ptr_;
133139

140+
std::unordered_map<lanelet::Id, TrafficSignal> traffic_signal_id_map_;
141+
134142
// parameter update
135143
OnSetParametersCallbackHandle::SharedPtr set_param_res_;
136144
rcl_interfaces::msg::SetParametersResult onParam(
@@ -181,11 +189,14 @@ class MapBasedPredictionNode : public rclcpp::Node
181189
double speed_limit_multiplier_;
182190
double acceleration_exponential_half_life_;
183191

192+
bool use_crosswalk_signal_;
193+
184194
// Stop watch
185195
StopWatch<std::chrono::milliseconds> stop_watch_;
186196

187197
// Member Functions
188198
void mapCallback(const HADMapBin::ConstSharedPtr msg);
199+
void trafficSignalsCallback(const TrafficSignalArray::ConstSharedPtr msg);
189200
void objectsCallback(const TrackedObjects::ConstSharedPtr in_objects);
190201

191202
bool doesPathCrossAnyFence(const PredictedPath & predicted_path);
@@ -249,6 +260,8 @@ class MapBasedPredictionNode : public rclcpp::Node
249260
const LaneletsData & lanelets_data);
250261
bool isDuplicated(
251262
const PredictedPath & predicted_path, const std::vector<PredictedPath> & predicted_paths);
263+
std::optional<lanelet::Id> getTrafficSignalId(const lanelet::ConstLanelet & way_lanelet);
264+
std::optional<TrafficSignalElement> getTrafficSignalElement(const lanelet::Id & id);
252265

253266
visualization_msgs::msg::Marker getDebugMarker(
254267
const TrackedObject & object, const Maneuver & maneuver, const size_t obj_num);

perception/map_based_prediction/launch/map_based_prediction.launch.xml

+2
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
<arg name="param_path" default="$(find-pkg-share map_based_prediction)/config/map_based_prediction.param.yaml"/>
44

55
<arg name="vector_map_topic" default="/map/vector_map"/>
6+
<arg name="traffic_signals_topic" default="/perception/traffic_light_recognition/traffic_signals"/>
67
<arg name="output_topic" default="objects"/>
78
<arg name="input_topic" default="/perception/object_recognition/tracking/objects"/>
89

910
<node pkg="map_based_prediction" exec="map_based_prediction" name="map_based_prediction" output="screen">
1011
<param from="$(var param_path)"/>
1112
<remap from="/vector_map" to="$(var vector_map_topic)"/>
13+
<remap from="/traffic_signals" to="$(var traffic_signals_topic)"/>
1214
<remap from="~/output/objects" to="$(var output_topic)"/>
1315
<remap from="~/input/objects" to="$(var input_topic)"/>
1416
</node>

perception/map_based_prediction/package.xml

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
<buildtool_depend>autoware_cmake</buildtool_depend>
1717

1818
<depend>autoware_auto_perception_msgs</depend>
19+
<depend>autoware_perception_msgs</depend>
1920
<depend>interpolation</depend>
2021
<depend>lanelet2_extension</depend>
2122
<depend>libgoogle-glog-dev</depend>

perception/map_based_prediction/src/map_based_prediction_node.cpp

+57
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,8 @@ MapBasedPredictionNode::MapBasedPredictionNode(const rclcpp::NodeOptions & node_
788788
acceleration_exponential_half_life_ =
789789
declare_parameter<double>("acceleration_exponential_half_life");
790790

791+
use_crosswalk_signal_ = declare_parameter<bool>("use_crosswalk_signal");
792+
791793
path_generator_ = std::make_shared<PathGenerator>(
792794
prediction_time_horizon_, lateral_control_time_horizon_, prediction_sampling_time_interval_,
793795
min_crosswalk_user_velocity_);
@@ -801,6 +803,9 @@ MapBasedPredictionNode::MapBasedPredictionNode(const rclcpp::NodeOptions & node_
801803
sub_map_ = this->create_subscription<HADMapBin>(
802804
"/vector_map", rclcpp::QoS{1}.transient_local(),
803805
std::bind(&MapBasedPredictionNode::mapCallback, this, std::placeholders::_1));
806+
sub_traffic_signals_ = this->create_subscription<TrafficSignalArray>(
807+
"/traffic_signals", 1,
808+
std::bind(&MapBasedPredictionNode::trafficSignalsCallback, this, std::placeholders::_1));
804809

805810
pub_objects_ = this->create_publisher<PredictedObjects>("~/output/objects", rclcpp::QoS{1});
806811
pub_debug_markers_ =
@@ -872,6 +877,14 @@ void MapBasedPredictionNode::mapCallback(const HADMapBin::ConstSharedPtr msg)
872877
crosswalks_.insert(crosswalks_.end(), walkways.begin(), walkways.end());
873878
}
874879

880+
void MapBasedPredictionNode::trafficSignalsCallback(const TrafficSignalArray::ConstSharedPtr msg)
881+
{
882+
traffic_signal_id_map_.clear();
883+
for (const auto & signal : msg->signals) {
884+
traffic_signal_id_map_[signal.traffic_signal_id] = signal;
885+
}
886+
}
887+
875888
void MapBasedPredictionNode::objectsCallback(const TrackedObjects::ConstSharedPtr in_objects)
876889
{
877890
stop_watch_.tic();
@@ -1218,6 +1231,18 @@ PredictedObject MapBasedPredictionNode::getPredictedObjectAsCrosswalkUser(
12181231
}
12191232
// try to find the edge points for all crosswalks and generate path to the crosswalk edge
12201233
for (const auto & crosswalk : crosswalks_) {
1234+
const auto crosswalk_signal_id_opt = getTrafficSignalId(crosswalk);
1235+
if (crosswalk_signal_id_opt.has_value() && use_crosswalk_signal_) {
1236+
const auto signal_color = [&] {
1237+
const auto elem_opt = getTrafficSignalElement(crosswalk_signal_id_opt.value());
1238+
return elem_opt ? elem_opt.value().color : TrafficSignalElement::UNKNOWN;
1239+
}();
1240+
1241+
if (signal_color == TrafficSignalElement::RED) {
1242+
continue;
1243+
}
1244+
}
1245+
12211246
const auto edge_points = getCrosswalkEdgePoints(crosswalk);
12221247

12231248
const auto reachable_first = hasPotentialToReach(
@@ -2211,6 +2236,38 @@ bool MapBasedPredictionNode::isDuplicated(
22112236

22122237
return false;
22132238
}
2239+
2240+
std::optional<lanelet::Id> MapBasedPredictionNode::getTrafficSignalId(
2241+
const lanelet::ConstLanelet & way_lanelet)
2242+
{
2243+
const auto traffic_light_reg_elems =
2244+
way_lanelet.regulatoryElementsAs<const lanelet::TrafficLight>();
2245+
if (traffic_light_reg_elems.empty()) {
2246+
return std::nullopt;
2247+
} else if (traffic_light_reg_elems.size() > 1) {
2248+
RCLCPP_ERROR(
2249+
get_logger(),
2250+
"[Map Based Prediction]: "
2251+
"Multiple regulatory elements as TrafficLight are defined to one lanelet object.");
2252+
}
2253+
return traffic_light_reg_elems.front()->id();
2254+
}
2255+
2256+
std::optional<TrafficSignalElement> MapBasedPredictionNode::getTrafficSignalElement(
2257+
const lanelet::Id & id)
2258+
{
2259+
if (traffic_signal_id_map_.count(id) != 0) {
2260+
const auto & signal_elements = traffic_signal_id_map_.at(id).elements;
2261+
if (signal_elements.size() > 1) {
2262+
RCLCPP_ERROR(
2263+
get_logger(), "[Map Based Prediction]: Multiple TrafficSignalElement_ are received.");
2264+
} else if (!signal_elements.empty()) {
2265+
return signal_elements.front();
2266+
}
2267+
}
2268+
return std::nullopt;
2269+
}
2270+
22142271
} // namespace map_based_prediction
22152272

22162273
#include <rclcpp_components/register_node_macro.hpp>

0 commit comments

Comments
 (0)