Skip to content

Commit aede063

Browse files
committed
fix: add roi overlapping segment
Signed-off-by: badai-nguyen <dai.nguyen@tier4.jp>
1 parent 07788b5 commit aede063

File tree

6 files changed

+129
-56
lines changed

6 files changed

+129
-56
lines changed

perception/tensorrt_yolox/include/tensorrt_yolox/tensorrt_yolox.hpp

+13-7
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ class TrtYoloX
7878
* @param[in] build_config configuration including precision, calibration method, DLA, remaining
7979
* fp16 for first layer, remaining fp16 for last layer and profiler for builder
8080
* @param[in] use_gpu_preprocess whether use cuda gpu for preprocessing
81+
* @param[in] publish_color_mask whether to publish the color mask for debugging and visualization
8182
* @param[in] calibration_image_list_file path for calibration files (only require for
8283
* quantization)
8384
* @param[in] norm_factor scaling factor for preprocess
@@ -90,8 +91,9 @@ class TrtYoloX
9091
const std::string & color_map_path, const int num_class = 8, const float score_threshold = 0.3,
9192
const float nms_threshold = 0.7,
9293
const tensorrt_common::BuildConfig build_config = tensorrt_common::BuildConfig(),
93-
const bool use_gpu_preprocess = false, std::string calibration_image_list_file = std::string(),
94-
const double norm_factor = 1.0, [[maybe_unused]] const std::string & cache_dir = "",
94+
const bool use_gpu_preprocess = false, const bool publish_color_mask = false,
95+
std::string calibration_image_list_file = std::string(), const double norm_factor = 1.0,
96+
[[maybe_unused]] const std::string & cache_dir = "",
9597
const tensorrt_common::BatchConfig & batch_config = {1, 1, 1},
9698
const size_t max_workspace_size = (1 << 30));
9799
/**
@@ -105,8 +107,8 @@ class TrtYoloX
105107
* @param[in] images batched images
106108
*/
107109
bool doInference(
108-
const std::vector<cv::Mat> & images, ObjectArrays & objects, cv::Mat & mask,
109-
cv::Mat & color_mask);
110+
const std::vector<cv::Mat> & images, ObjectArrays & objects, std::vector<cv::Mat> & masks,
111+
std::vector<cv::Mat> & color_masks);
110112

111113
/**
112114
* @brief run inference including pre-process and post-process
@@ -201,8 +203,8 @@ class TrtYoloX
201203

202204
bool feedforward(const std::vector<cv::Mat> & images, ObjectArrays & objects);
203205
bool feedforwardAndDecode(
204-
const std::vector<cv::Mat> & images, ObjectArrays & objects, cv::Mat & mask,
205-
cv::Mat & color_mask);
206+
const std::vector<cv::Mat> & images, ObjectArrays & objects, std::vector<cv::Mat> & masks,
207+
std::vector<cv::Mat> & color_masks);
206208
void decodeOutputs(float * prob, ObjectArray & objects, float scale, cv::Size & img_size) const;
207209
void generateGridsAndStride(
208210
const int target_w, const int target_h, const std::vector<int> & strides,
@@ -307,7 +309,11 @@ class TrtYoloX
307309
CudaUniquePtrHost<unsigned char[]> argmax_buf_h_;
308310
// device buffer for argmax postprocessing on GPU
309311
CudaUniquePtr<unsigned char[]> argmax_buf_d_;
310-
std::vector<tensorrt_yolox::Colormap> color_map_;
312+
std::vector<tensorrt_yolox::Colormap> sematic_color_map_;
313+
// flag whether overlay segmentation by roi
314+
bool roi_overlap_segment_;
315+
// flag whether to publish the color mask for debugging and visualization
316+
bool publish_color_mask_;
311317
};
312318

313319
} // namespace tensorrt_yolox

perception/tensorrt_yolox/include/tensorrt_yolox/tensorrt_yolox_node.hpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ class TrtYoloXNode : public rclcpp::Node
5151
void onImage(const sensor_msgs::msg::Image::ConstSharedPtr msg);
5252
bool readLabelFile(const std::string & label_path);
5353
void replaceLabelMap();
54-
54+
void overlapSegmentByRoi(const tensorrt_yolox::Object & object, cv::Mat & mask);
55+
int mapRoiLabel2SegLabel(const int32_t roi_label_index);
5556
image_transport::Publisher image_pub_;
5657
image_transport::Publisher mask_pub_;
5758
image_transport::Publisher color_mask_pub_;
@@ -64,6 +65,9 @@ class TrtYoloXNode : public rclcpp::Node
6465

6566
LabelMap label_map_;
6667
std::unique_ptr<tensorrt_yolox::TrtYoloX> trt_yolox_;
68+
bool is_roi_overlap_segment_;
69+
bool is_publish_color_mask_;
70+
float overlap_roi_score_threshold_;
6771
};
6872

6973
} // namespace tensorrt_yolox

perception/tensorrt_yolox/launch/yolox_s_plus_opt.launch.xml

+7-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@
3232
<arg name="calibration_image_list_path" default="" description="Path to a file which contains path to images. Those images will be used for int8 quantization."/>
3333
<arg name="use_decompress" default="true" description="use image decompress"/>
3434
<arg name="build_only" default="false" description="exit after trt engine is built"/>
35-
35+
<arg name="is_roi_overlap_segment" default="true" description="refine segmentation mask by overlay roi class, disable when sematic segmentation accuracy is good enough"/>
36+
<arg name="is_publish_color_mask" default="false" description="publish color mask for result visualization"/>
37+
<arg name="overlap_roi_score_threshold" default="0.3" description="the roi object existance probability threshold that consider to replace segmentation"/>
3638
<node pkg="image_transport_decompressor" exec="image_transport_decompressor_node" name="image_transport_decompressor_node" if="$(var use_decompress)">
3739
<remap from="~/input/compressed_image" to="$(var input/image)/compressed"/>
3840
<remap from="~/output/raw_image" to="$(var input/image)"/>
@@ -56,6 +58,9 @@
5658
<param name="preprocess_on_gpu" value="$(var preprocess_on_gpu)"/>
5759
<param name="calibration_image_list_path" value="$(var calibration_image_list_path)"/>
5860
<param name="build_only" value="$(var build_only)"/>
59-
<param name="color_map_path" value="$(var model_path)/bdd100k_semseg.csv"/>
61+
<param name="color_map_path" value="$(var model_path)/semseg_color_map.csv"/>
62+
<param name="is_roi_overlap_segment" value="$(var is_roi_overlap_segment)"/>
63+
<param name="is_publish_color_mask" value="$(var is_publish_color_mask)"/>
64+
<param name="overlap_roi_score_threshold" value="$(var overlap_roi_score_threshold)"/>
6065
</node>
6166
</launch>

perception/tensorrt_yolox/src/tensorrt_yolox.cpp

+42-26
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ std::vector<tensorrt_yolox::Colormap> get_seg_colormap(const std::string & filen
109109
std::vector<tensorrt_yolox::Colormap> seg_cmap;
110110
if (filename != "not-specified") {
111111
std::vector<std::string> color_list = loadListFromTextFile(filename);
112-
for (int i = 0; i < (int)color_list.size(); i++) {
112+
for (int i = 0; i < static_cast<int>(color_list.size()); i++) {
113113
if (i == 0) {
114114
// Skip header
115115
continue;
@@ -120,7 +120,7 @@ std::vector<tensorrt_yolox::Colormap> get_seg_colormap(const std::string & filen
120120
size_t npos = colormapString.find_first_of(',');
121121
assert(npos != std::string::npos);
122122
std::string substr = colormapString.substr(0, npos);
123-
int id = (int)std::stoi(trim(substr));
123+
int id = static_cast<int>(std::stoi(trim(substr)));
124124
colormapString.erase(0, npos + 1);
125125

126126
npos = colormapString.find_first_of(',');
@@ -157,7 +157,7 @@ namespace tensorrt_yolox
157157
TrtYoloX::TrtYoloX(
158158
const std::string & model_path, const std::string & precision, const std::string & color_map_path,
159159
const int num_class, const float score_threshold, const float nms_threshold,
160-
tensorrt_common::BuildConfig build_config, const bool use_gpu_preprocess,
160+
tensorrt_common::BuildConfig build_config, const bool use_gpu_preprocess, bool publish_color_mask,
161161
std::string calibration_image_list_path, const double norm_factor,
162162
[[maybe_unused]] const std::string & cache_dir, const tensorrt_common::BatchConfig & batch_config,
163163
const size_t max_workspace_size)
@@ -167,7 +167,8 @@ TrtYoloX::TrtYoloX(
167167
norm_factor_ = norm_factor;
168168
batch_size_ = batch_config[2];
169169
multitask_ = 0;
170-
color_map_ = get_seg_colormap(color_map_path);
170+
sematic_color_map_ = get_seg_colormap(color_map_path);
171+
publish_color_mask_ = publish_color_mask;
171172
if (precision == "int8") {
172173
if (build_config.clip_value <= 0.0) {
173174
if (calibration_image_list_path.empty()) {
@@ -388,13 +389,14 @@ void TrtYoloX::initPreprocessBuffer(int width, int height)
388389
for (int m = 0; m < multitask_; m++) {
389390
const auto output_dims =
390391
trt_common_->getBindingDimensions(m + 2); // 0 : input, 1 : output for detections
391-
const float scale =
392-
std::min(output_dims.d[3] / float(width), output_dims.d[2] / float(height));
393-
int out_w = (int)(width * scale);
394-
int out_h = (int)(height * scale);
395-
// size_t out_elem_num = std::accumulate(
392+
const float scale = std::min(
393+
output_dims.d[3] / static_cast<float>(width),
394+
output_dims.d[2] / static_cast<float>(height));
395+
int out_w = static_cast<int>(width * scale);
396+
int out_h = static_cast<int>(height * scale);
397+
// size_t out_elem_num = std::accumulate(
396398
// output_dims.d + 1, output_dims.d + output_dims.nbDims, 1, std::multiplies<int>());
397-
// out_elem_num = out_elem_num * batch_size_;
399+
// out_elem_num = out_elem_num * batch_size_;
398400
size_t out_elem_num = out_w * out_h * batch_size_;
399401
argmax_out_elem_num += out_elem_num;
400402
}
@@ -468,8 +470,9 @@ void TrtYoloX::preprocessGpu(const std::vector<cv::Mat> & images)
468470
for (int m = 0; m < multitask_; m++) {
469471
const auto output_dims =
470472
trt_common_->getBindingDimensions(m + 2); // 0: input, 1: output for detections
471-
const float scale =
472-
std::min(output_dims.d[3] / float(image.cols), output_dims.d[2] / float(image.rows));
473+
const float scale = std::min(
474+
output_dims.d[3] / static_cast<float>(image.cols),
475+
output_dims.d[2] / static_cast<float>(image.rows));
473476
int out_w = static_cast<int>(image.cols * scale);
474477
int out_h = static_cast<int>(image.rows * scale);
475478
argmax_out_elem_num += out_w * out_h * batch_size;
@@ -545,8 +548,8 @@ void TrtYoloX::preprocess(const std::vector<cv::Mat> & images)
545548
}
546549

547550
bool TrtYoloX::doInference(
548-
const std::vector<cv::Mat> & images, ObjectArrays & objects, cv::Mat & mask,
549-
[[maybe_unused]] cv::Mat & color_mask)
551+
const std::vector<cv::Mat> & images, ObjectArrays & objects, std::vector<cv::Mat> & masks,
552+
[[maybe_unused]] std::vector<cv::Mat> & color_masks)
550553
{
551554
if (!trt_common_->isInitialized()) {
552555
return false;
@@ -559,7 +562,7 @@ bool TrtYoloX::doInference(
559562
}
560563

561564
if (needs_output_decode_) {
562-
return feedforwardAndDecode(images, objects, mask, color_mask);
565+
return feedforwardAndDecode(images, objects, masks, color_masks);
563566
} else {
564567
return feedforward(images, objects);
565568
}
@@ -799,8 +802,8 @@ void TrtYoloX::multiScalePreprocess(const cv::Mat & image, const std::vector<cv:
799802
bool TrtYoloX::doInferenceWithRoi(
800803
const std::vector<cv::Mat> & images, ObjectArrays & objects, const std::vector<cv::Rect> & rois)
801804
{
802-
cv::Mat mask;
803-
cv::Mat color_mask;
805+
std::vector<cv::Mat> masks;
806+
std::vector<cv::Mat> color_masks;
804807
if (!trt_common_->isInitialized()) {
805808
return false;
806809
}
@@ -811,7 +814,7 @@ bool TrtYoloX::doInferenceWithRoi(
811814
}
812815

813816
if (needs_output_decode_) {
814-
return feedforwardAndDecode(images, objects, mask, color_mask);
817+
return feedforwardAndDecode(images, objects, masks, color_masks);
815818
} else {
816819
return feedforward(images, objects);
817820
}
@@ -890,8 +893,8 @@ bool TrtYoloX::feedforward(const std::vector<cv::Mat> & images, ObjectArrays & o
890893
}
891894

892895
bool TrtYoloX::feedforwardAndDecode(
893-
const std::vector<cv::Mat> & images, ObjectArrays & objects, cv::Mat & out_mask,
894-
[[maybe_unused]] cv::Mat & color_mask)
896+
const std::vector<cv::Mat> & images, ObjectArrays & objects, std::vector<cv::Mat> & out_masks,
897+
[[maybe_unused]] std::vector<cv::Mat> & color_masks)
895898
{
896899
std::vector<void *> buffers = {input_d_.get(), out_prob_d_.get()};
897900
if (multitask_) {
@@ -914,26 +917,31 @@ bool TrtYoloX::feedforwardAndDecode(
914917

915918
for (size_t i = 0; i < batch_size; ++i) {
916919
auto image_size = images[i].size();
920+
auto & out_mask = out_masks[i];
921+
auto & color_mask = color_masks[i];
917922
float * batch_prob = out_prob_h_.get() + (i * out_elem_num_per_batch_);
918923
ObjectArray object_array;
919924
decodeOutputs(batch_prob, object_array, scales_[i], image_size);
925+
// add refine mask using object
920926
objects.emplace_back(object_array);
921927
if (multitask_) {
922928
segmentation_masks_.clear();
923929
float * segmentation_results =
924930
segmentation_out_prob_h_.get() + (i * segmentation_out_elem_num_per_batch_);
925931
size_t counter = 0;
926-
int batch = (int)(segmentation_out_elem_num_ / segmentation_out_elem_num_per_batch_);
932+
int batch =
933+
static_cast<int>(segmentation_out_elem_num_ / segmentation_out_elem_num_per_batch_);
927934
for (int m = 0; m < multitask_; m++) {
928935
const auto output_dims =
929936
trt_common_->getBindingDimensions(m + 2); // 0 : input, 1 : output for detections
930937
size_t out_elem_num = std::accumulate(
931938
output_dims.d + 1, output_dims.d + output_dims.nbDims, 1, std::multiplies<int>());
932939
out_elem_num = out_elem_num * batch;
933940
const float scale = std::min(
934-
output_dims.d[3] / float(image_size.width), output_dims.d[2] / float(image_size.height));
935-
int out_w = (int)(image_size.width * scale);
936-
int out_h = (int)(image_size.height * scale);
941+
output_dims.d[3] / static_cast<float>(image_size.width),
942+
output_dims.d[2] / static_cast<float>(image_size.height));
943+
int out_w = static_cast<int>(image_size.width * scale);
944+
int out_h = static_cast<int>(image_size.height * scale);
937945
cv::Mat mask;
938946
if (use_gpu_preprocess_) {
939947
float * d_segmentation_results =
@@ -945,8 +953,16 @@ bool TrtYoloX::feedforwardAndDecode(
945953
segmentation_masks_.push_back(mask);
946954
counter += out_elem_num;
947955
}
948-
out_mask = segmentation_masks_.at(0);
949-
color_mask = getColorizedMask(0, color_map_);
956+
} else {
957+
continue;
958+
}
959+
// Assume semantic segmentation is first task
960+
// This should be removed once the segmentation accuracy is high enough
961+
out_mask = segmentation_masks_.at(0);
962+
963+
// publish color mask for visualization
964+
if (publish_color_mask_) {
965+
color_mask = getColorizedMask(0, sematic_color_map_);
950966
}
951967
}
952968
return true;

perception/tensorrt_yolox/src/tensorrt_yolox_node.cpp

+59-17
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ TrtYoloXNode::TrtYoloXNode(const rclcpp::NodeOptions & node_options)
9494
RCLCPP_ERROR(this->get_logger(), "Could not find label file");
9595
rclcpp::shutdown();
9696
}
97+
98+
is_roi_overlap_segment_ = declare_parameter<bool>("is_roi_overlap_segment");
99+
is_publish_color_mask_ = declare_parameter<bool>("is_publish_color_mask");
100+
overlap_roi_score_threshold_ = declare_parameter<float>("overlap_roi_score_threshold");
97101
replaceLabelMap();
98102

99103
tensorrt_common::BuildConfig build_config(
@@ -102,7 +106,7 @@ TrtYoloXNode::TrtYoloXNode(const rclcpp::NodeOptions & node_options)
102106

103107
trt_yolox_ = std::make_unique<tensorrt_yolox::TrtYoloX>(
104108
model_path, precision, color_map_path, label_map_.size(), score_threshold, nms_threshold,
105-
build_config, preprocess_on_gpu, calibration_image_list_path);
109+
build_config, preprocess_on_gpu, is_publish_color_mask_, calibration_image_list_path);
106110

107111
timer_ =
108112
rclcpp::create_timer(this, get_clock(), 100ms, std::bind(&TrtYoloXNode::onConnect, this));
@@ -150,13 +154,19 @@ void TrtYoloXNode::onImage(const sensor_msgs::msg::Image::ConstSharedPtr msg)
150154
const auto height = in_image_ptr->image.rows;
151155

152156
tensorrt_yolox::ObjectArrays objects;
153-
cv::Mat mask(cv::Size(height, width), CV_8UC1, cv::Scalar(0));
154-
cv::Mat color_mask(cv::Size(height, width), CV_8UC3, cv::Scalar(0, 0, 0));
157+
std::vector<cv::Mat> masks = {cv::Mat(cv::Size(height, width), CV_8UC1, cv::Scalar(0))};
158+
std::vector<cv::Mat> color_masks = {
159+
cv::Mat(cv::Size(height, width), CV_8UC3, cv::Scalar(0, 0, 0))};
155160

156-
if (!trt_yolox_->doInference({in_image_ptr->image}, objects, mask, color_mask)) {
161+
if (!trt_yolox_->doInference({in_image_ptr->image}, objects, masks, color_masks)) {
157162
RCLCPP_WARN(this->get_logger(), "Fail to inference");
158163
return;
159164
}
165+
auto & mask = masks.at(0);
166+
cv::resize(
167+
mask, mask, cv::Size(in_image_ptr->image.cols, in_image_ptr->image.rows), 0, 0,
168+
cv::INTER_NEAREST);
169+
160170
for (const auto & yolox_object : objects.at(0)) {
161171
tier4_perception_msgs::msg::DetectedObjectWithFeature object;
162172
object.feature.roi.x_offset = yolox_object.x_offset;
@@ -176,29 +186,32 @@ void TrtYoloXNode::onImage(const sensor_msgs::msg::Image::ConstSharedPtr msg)
176186
cv::rectangle(
177187
in_image_ptr->image, cv::Point(left, top), cv::Point(right, bottom), cv::Scalar(0, 0, 255), 3,
178188
8, 0);
189+
// Refine mask: replacing segmentation mask by roi class
190+
if (is_roi_overlap_segment_) {
191+
overlapSegmentByRoi(yolox_object, mask);
192+
}
179193
}
180-
cv::resize(
181-
mask, mask, cv::Size(in_image_ptr->image.cols, in_image_ptr->image.rows), 0, 0,
182-
cv::INTER_NEAREST);
183194
sensor_msgs::msg::Image::SharedPtr out_mask_msg =
184195
cv_bridge::CvImage(std_msgs::msg::Header(), sensor_msgs::image_encodings::MONO8, mask)
185196
.toImageMsg();
186197
out_mask_msg->header = msg->header;
187198
mask_pub_.publish(out_mask_msg);
188199

189-
cv::resize(
190-
color_mask, color_mask, cv::Size(in_image_ptr->image.cols, in_image_ptr->image.rows), 0, 0,
191-
cv::INTER_NEAREST);
192-
sensor_msgs::msg::Image::SharedPtr output_color_mask_msg =
193-
cv_bridge::CvImage(std_msgs::msg::Header(), sensor_msgs::image_encodings::BGR8, color_mask)
194-
.toImageMsg();
195-
output_color_mask_msg->header = msg->header;
196-
color_mask_pub_.publish(output_color_mask_msg);
197-
198200
image_pub_.publish(in_image_ptr->toImageMsg());
199-
200201
out_objects.header = msg->header;
201202
objects_pub_->publish(out_objects);
203+
204+
if (is_publish_color_mask_) {
205+
auto & color_mask = color_masks.at(0);
206+
cv::resize(
207+
color_mask, color_mask, cv::Size(in_image_ptr->image.cols, in_image_ptr->image.rows), 0, 0,
208+
cv::INTER_NEAREST);
209+
sensor_msgs::msg::Image::SharedPtr output_color_mask_msg =
210+
cv_bridge::CvImage(std_msgs::msg::Header(), sensor_msgs::image_encodings::BGR8, color_mask)
211+
.toImageMsg();
212+
output_color_mask_msg->header = msg->header;
213+
color_mask_pub_.publish(output_color_mask_msg);
214+
}
202215
}
203216

204217
bool TrtYoloXNode::readLabelFile(const std::string & label_path)
@@ -235,6 +248,35 @@ void TrtYoloXNode::replaceLabelMap()
235248
}
236249
}
237250

251+
int TrtYoloXNode::mapRoiLabel2SegLabel(const int32_t roi_label_index)
252+
{
253+
auto & roi_label = label_map_[roi_label_index];
254+
if (roi_label == "CAR" || roi_label == "BUS" || roi_label == "TRUCK") {
255+
return static_cast<int>(roi_label_index + 11);
256+
}
257+
if (roi_label == "PEDESTRIAN") {
258+
return 11; // person index in segment_color_map
259+
}
260+
if (roi_label == "MOTORCYCLE") {
261+
return 17; // motocycle index in segment_color_map
262+
}
263+
if (roi_label == "BICYCLE") {
264+
return 18; // bicycle index in segment_color_map
265+
}
266+
return -1;
267+
}
268+
269+
/**
 * @brief Overwrite the part of a segmentation mask covered by a detected ROI
 *        with the segmentation class that corresponds to the ROI label.
 *
 * Nothing is written when the ROI score is below overlap_roi_score_threshold_,
 * when the ROI label has no segmentation counterpart (mapRoiLabel2SegLabel
 * returns a negative value), or when the ROI does not intersect the mask.
 *
 * @param[in] roi_object detected object (offset, size, score and label type are read)
 * @param[in,out] mask single-channel class-index mask, modified in place.
 *        NOTE(review): assumes the mask has already been resized to the same
 *        coordinate frame as the ROI -- confirm against the caller.
 */
void TrtYoloXNode::overlapSegmentByRoi(const tensorrt_yolox::Object & roi_object, cv::Mat & mask)
{
  if (roi_object.score < overlap_roi_score_threshold_) return;

  const int seg_class_index = mapRoiLabel2SegLabel(roi_object.type);
  if (seg_class_index < 0) return;

  // Build the ROI as (x, y, width, height). The previous code passed width/height
  // as the *end* index of colRange()/rowRange(), which selects the wrong region
  // and makes OpenCV raise an error whenever width < x_offset or height < y_offset.
  const cv::Rect roi_rect(
    static_cast<int>(roi_object.x_offset), static_cast<int>(roi_object.y_offset),
    static_cast<int>(roi_object.width), static_cast<int>(roi_object.height));

  // Clip to the mask so a box touching or crossing the border cannot index
  // outside the matrix.
  const cv::Rect clipped_rect = roi_rect & cv::Rect(0, 0, mask.cols, mask.rows);
  if (clipped_rect.width <= 0 || clipped_rect.height <= 0) return;

  // Fill the region in place. The previous code copied from a zero-sized
  // matrix, so the class id was never written into the mask.
  mask(clipped_rect).setTo(cv::Scalar(seg_class_index));
}
279+
238280
} // namespace tensorrt_yolox
239281

240282
#include "rclcpp_components/register_node_macro.hpp"

0 commit comments

Comments
 (0)