Skip to content

Commit e301a20

Browse files
committed
chore: refactor
Signed-off-by: badai-nguyen <dai.nguyen@tier4.jp>
1 parent f8e6bd0 commit e301a20

8 files changed

+61
-25
lines changed

perception/tensorrt_yolox/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -114,4 +114,5 @@ endif()
114114

115115
ament_auto_package(INSTALL_TO_SHARE
116116
launch
117+
config
117118
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
/**:
2+
ros__parameters:
3+
# refine segmentation mask by overlay roi class
4+
# disable when sematic segmentation accuracy is good enough
5+
is_roi_overlap_segment: true
6+
7+
# minimum existence_probability of detected roi considered to replace segmentation
8+
overlap_roi_score_threshold: 0.3
9+
10+
# publish color mask for result visualization
11+
is_publish_color_mask: false
12+
13+
roi_overlay_segment_label:
14+
UNKNOWN : true
15+
CAR : true
16+
TRUCK : true
17+
BUS : true
18+
MOTORCYCLE : true
19+
BICYCLE : true
20+
PEDESTRIAN : true

perception/tensorrt_yolox/include/tensorrt_yolox/tensorrt_yolox.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ class TrtYoloX
153153
* @param[in] index multitask index
154154
* @param[in] colormap colormap for masks
155155
*/
156-
void getColorizedMask(
156+
void getColorizedMask(
157157
const std::vector<tensorrt_yolox::Colormap> & colormap, const cv::Mat & mask,
158158
cv::Mat & colorized_mask);
159159
inline std::vector<Colormap> getColorMap() { return sematic_color_map_; }

perception/tensorrt_yolox/include/tensorrt_yolox/tensorrt_yolox_node.hpp

+15
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
#ifndef TENSORRT_YOLOX__TENSORRT_YOLOX_NODE_HPP_
1616
#define TENSORRT_YOLOX__TENSORRT_YOLOX_NODE_HPP_
1717

18+
#include "object_recognition_utils/object_recognition_utils.hpp"
19+
#include "utils/utils.hpp"
20+
1821
#include <image_transport/image_transport.hpp>
1922
#include <opencv2/opencv.hpp>
2023
#include <rclcpp/rclcpp.hpp>
@@ -68,6 +71,18 @@ class TrtYoloXNode : public rclcpp::Node
6871
bool is_roi_overlap_segment_;
6972
bool is_publish_color_mask_;
7073
float overlap_roi_score_threshold_;
74+
// TODO(badai-nguyen): change to function
75+
std::map<std::string, int> remap_roi_to_semantic_ = {
76+
{"UNKNOWN", 19}, // other
77+
{"ANIMAL", 19}, // other
78+
{"PEDESTRIAN", 11}, // person
79+
{"CAR", 13}, // car
80+
{"TRUCK", 14}, // truck
81+
{"BUS", 15}, // bus
82+
{"BICYCLE", 18}, // bicycle
83+
{"MOTORBIKE", 17}, // motorcycle
84+
};
85+
utils::FilterTargetLabel roi_overlay_segment_labels_;
7186
};
7287

7388
} // namespace tensorrt_yolox

perception/tensorrt_yolox/launch/yolox_s_plus_opt.launch.xml

+2-6
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,7 @@
3232
<arg name="calibration_image_list_path" default="" description="Path to a file which contains path to images. Those images will be used for int8 quantization."/>
3333
<arg name="use_decompress" default="true" description="use image decompress"/>
3434
<arg name="build_only" default="false" description="exit after trt engine is built"/>
35-
<arg name="is_roi_overlap_segment" default="true" description="refine segmentation mask by overlay roi class, disable when sematic segmentation accuracy is good enough"/>
36-
<arg name="is_publish_color_mask" default="false" description="publish color mask for result visualization"/>
37-
<arg name="overlap_roi_score_threshold" default="0.3" description="the roi object existance probability threshold that consider to replace segmentation"/>
35+
<arg name="yolox_s_plus_opt_param_path" default="$(find-pkg-share tensorrt_yolox)/config/yolox_s_plus_opt.param.yaml"/>
3836
<node pkg="image_transport_decompressor" exec="image_transport_decompressor_node" name="image_transport_decompressor_node" if="$(var use_decompress)">
3937
<remap from="~/input/compressed_image" to="$(var input/image)/compressed"/>
4038
<remap from="~/output/raw_image" to="$(var input/image)"/>
@@ -59,8 +57,6 @@
5957
<param name="calibration_image_list_path" value="$(var calibration_image_list_path)"/>
6058
<param name="build_only" value="$(var build_only)"/>
6159
<param name="color_map_path" value="$(var model_path)/semseg_color_map.csv"/>
62-
<param name="is_roi_overlap_segment" value="$(var is_roi_overlap_segment)"/>
63-
<param name="is_publish_color_mask" value="$(var is_publish_color_mask)"/>
64-
<param name="overlap_roi_score_threshold" value="$(var overlap_roi_score_threshold)"/>
60+
<param from="$(var yolox_s_plus_opt_param_path)"/>
6561
</node>
6662
</launch>

perception/tensorrt_yolox/package.xml

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
<depend>autoware_auto_perception_msgs</depend>
2323
<depend>cuda_utils</depend>
2424
<depend>cv_bridge</depend>
25+
<depend>detected_object_validation</depend>
2526
<depend>image_transport</depend>
2627
<depend>libopencv-dev</depend>
2728
<depend>object_recognition_utils</depend>

perception/tensorrt_yolox/src/tensorrt_yolox.cpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -1281,7 +1281,12 @@ void TrtYoloX::getColorizedMask(
12811281
{
12821282
int width = mask.cols;
12831283
int height = mask.rows;
1284-
// TODO: check size of mask and cmask
1284+
if ((cmask.cols != mask.cols) || (cmask.rows != mask.rows)) {
1285+
RCLCPP_WARN_THROTTLE(
1286+
this->get_logger(), *this->get_clock(), 5,
1287+
"the input and output image's size should be the same!");
1288+
return;
1289+
}
12851290
for (int y = 0; y < height; y++) {
12861291
for (int x = 0; x < width; x++) {
12871292
unsigned char id = mask.at<unsigned char>(y, x);

perception/tensorrt_yolox/src/tensorrt_yolox_node.cpp

+15-17
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,17 @@ TrtYoloXNode::TrtYoloXNode(const rclcpp::NodeOptions & node_options)
9898
is_roi_overlap_segment_ = declare_parameter<bool>("is_roi_overlap_segment");
9999
is_publish_color_mask_ = declare_parameter<bool>("is_publish_color_mask");
100100
overlap_roi_score_threshold_ = declare_parameter<float>("overlap_roi_score_threshold");
101+
roi_overlay_segment_labels_.UNKNOWN =
102+
declare_parameter<bool>("roi_overlay_segment_label.UNKNOWN");
103+
roi_overlay_segment_labels_.CAR = declare_parameter<bool>("roi_overlay_segment_label.CAR");
104+
roi_overlay_segment_labels_.TRUCK = declare_parameter<bool>("roi_overlay_segment_label.TRUCK");
105+
roi_overlay_segment_labels_.BUS = declare_parameter<bool>("roi_overlay_segment_label.BUS");
106+
roi_overlay_segment_labels_.MOTORCYCLE =
107+
declare_parameter<bool>("roi_overlay_segment_label.MOTORCYCLE");
108+
roi_overlay_segment_labels_.BICYCLE =
109+
declare_parameter<bool>("roi_overlay_segment_label.BICYCLE");
110+
roi_overlay_segment_labels_.PEDESTRIAN =
111+
declare_parameter<bool>("roi_overlay_segment_label.PEDESTRIAN");
101112
replaceLabelMap();
102113

103114
tensorrt_common::BuildConfig build_config(
@@ -250,23 +261,10 @@ void TrtYoloXNode::replaceLabelMap()
250261

251262
int TrtYoloXNode::mapRoiLabel2SegLabel(const int32_t roi_label_index)
252263
{
253-
switch (roi_label_index) {
254-
case 0:
255-
return 5;
256-
case 1:
257-
return 13;
258-
case 2:
259-
return 14;
260-
case 3:
261-
return 15;
262-
case 4:
263-
return 18;
264-
case 5:
265-
return 17;
266-
case 6:
267-
return 11;
268-
default:
269-
return -1;
264+
if (roi_overlay_segment_labels_.isTarget(static_cast<uint8_t>(roi_label_index))) {
265+
std::string label = label_map_[roi_label_index];
266+
267+
return remap_roi_to_semantic_[label];
270268
}
271269
return -1;
272270
}

0 commit comments

Comments
 (0)