Skip to content

Commit f8e6bd0

Browse files
committed
fix: roi overlay segment
Signed-off-by: badai-nguyen <dai.nguyen@tier4.jp>
1 parent aede063 commit f8e6bd0

File tree

3 files changed

+33
-33
lines changed

3 files changed

+33
-33
lines changed

perception/tensorrt_yolox/include/tensorrt_yolox/tensorrt_yolox.hpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,10 @@ class TrtYoloX
153153
* @param[in] index multitask index
154154
* @param[in] colormap colormap for masks
155155
*/
156-
cv::Mat getColorizedMask(int index, std::vector<Colormap> & colormap);
156+
void getColorizedMask(
157+
const std::vector<tensorrt_yolox::Colormap> & colormap, const cv::Mat & mask,
158+
cv::Mat & colorized_mask);
159+
inline std::vector<Colormap> getColorMap() { return sematic_color_map_; }
157160

158161
private:
159162
/**

perception/tensorrt_yolox/src/tensorrt_yolox.cpp

+3-12
Original file line numberDiff line numberDiff line change
@@ -918,7 +918,6 @@ bool TrtYoloX::feedforwardAndDecode(
918918
for (size_t i = 0; i < batch_size; ++i) {
919919
auto image_size = images[i].size();
920920
auto & out_mask = out_masks[i];
921-
auto & color_mask = color_masks[i];
922921
float * batch_prob = out_prob_h_.get() + (i * out_elem_num_per_batch_);
923922
ObjectArray object_array;
924923
decodeOutputs(batch_prob, object_array, scales_[i], image_size);
@@ -957,13 +956,7 @@ bool TrtYoloX::feedforwardAndDecode(
957956
continue;
958957
}
959958
// Assume semantic segmentation is first task
960-
// This should remove when the segmentation accuracy is high
961959
out_mask = segmentation_masks_.at(0);
962-
963-
// publish color mask for visualization
964-
if (publish_color_mask_) {
965-
color_mask = getColorizedMask(0, sematic_color_map_);
966-
}
967960
}
968961
return true;
969962
}
@@ -1283,13 +1276,12 @@ int TrtYoloX::getMultitaskNum(void)
12831276
return multitask_;
12841277
}
12851278

1286-
cv::Mat TrtYoloX::getColorizedMask(int index, std::vector<Colormap> & colormap)
1279+
void TrtYoloX::getColorizedMask(
1280+
const std::vector<tensorrt_yolox::Colormap> & colormap, const cv::Mat & mask, cv::Mat & cmask)
12871281
{
1288-
cv::Mat mask;
1289-
mask = segmentation_masks_[index];
12901282
int width = mask.cols;
12911283
int height = mask.rows;
1292-
cv::Mat cmask = cv::Mat::zeros(height, width, CV_8UC3);
1284+
// TODO: check size of mask and cmask
12931285
for (int y = 0; y < height; y++) {
12941286
for (int x = 0; x < width; x++) {
12951287
unsigned char id = mask.at<unsigned char>(y, x);
@@ -1298,7 +1290,6 @@ cv::Mat TrtYoloX::getColorizedMask(int index, std::vector<Colormap> & colormap)
12981290
cmask.at<cv::Vec3b>(y, x)[2] = colormap[id].color[0];
12991291
}
13001292
}
1301-
return cmask;
13021293
}
13031294

13041295
} // namespace tensorrt_yolox

perception/tensorrt_yolox/src/tensorrt_yolox_node.cpp

+26-20
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ void TrtYoloXNode::onImage(const sensor_msgs::msg::Image::ConstSharedPtr msg)
187187
in_image_ptr->image, cv::Point(left, top), cv::Point(right, bottom), cv::Scalar(0, 0, 255), 3,
188188
8, 0);
189189
// Refine mask: replacing segmentation mask by roi class
190+
// This should remove when the segmentation accuracy is high
190191
if (is_roi_overlap_segment_) {
191192
overlapSegmentByRoi(yolox_object, mask);
192193
}
@@ -202,10 +203,9 @@ void TrtYoloXNode::onImage(const sensor_msgs::msg::Image::ConstSharedPtr msg)
202203
objects_pub_->publish(out_objects);
203204

204205
if (is_publish_color_mask_) {
205-
auto & color_mask = color_masks.at(0);
206-
cv::resize(
207-
color_mask, color_mask, cv::Size(in_image_ptr->image.cols, in_image_ptr->image.rows), 0, 0,
208-
cv::INTER_NEAREST);
206+
cv::Mat color_mask =
207+
cv::Mat::zeros(in_image_ptr->image.rows, in_image_ptr->image.cols, CV_8UC3);
208+
trt_yolox_->getColorizedMask(trt_yolox_->getColorMap(), mask, color_mask);
209209
sensor_msgs::msg::Image::SharedPtr output_color_mask_msg =
210210
cv_bridge::CvImage(std_msgs::msg::Header(), sensor_msgs::image_encodings::BGR8, color_mask)
211211
.toImageMsg();
@@ -250,31 +250,37 @@ void TrtYoloXNode::replaceLabelMap()
250250

251251
int TrtYoloXNode::mapRoiLabel2SegLabel(const int32_t roi_label_index)
252252
{
253-
auto & roi_label = label_map_[roi_label_index];
254-
if (roi_label == "CAR" || roi_label == "BUS" || roi_label == "TRUCK") {
255-
return static_cast<int>(roi_label_index + 11);
256-
}
257-
if (roi_label == "PEDESTRIAN") {
258-
return 11; // person index in segment_color_map
259-
}
260-
if (roi_label == "MOTORCYCLE") {
261-
return 17; // motocycle index in segment_color_map
262-
}
263-
if (roi_label == "BICYCLE") {
264-
return 18; // bicycle index in segment_color_map
253+
switch (roi_label_index) {
254+
case 0:
255+
return 5;
256+
case 1:
257+
return 13;
258+
case 2:
259+
return 14;
260+
case 3:
261+
return 15;
262+
case 4:
263+
return 18;
264+
case 5:
265+
return 17;
266+
case 6:
267+
return 11;
268+
default:
269+
return -1;
265270
}
266271
return -1;
267272
}
268273

269274
void TrtYoloXNode::overlapSegmentByRoi(const tensorrt_yolox::Object & roi_object, cv::Mat & mask)
270275
{
271276
if (roi_object.score < overlap_roi_score_threshold_) return;
272-
cv::Mat submat = mask.colRange(roi_object.x_offset, roi_object.width)
273-
.rowRange(roi_object.y_offset, roi_object.height);
274277
int seg_class_index = mapRoiLabel2SegLabel(roi_object.type);
275278
if (seg_class_index < 0) return;
276-
cv::Mat replace_roi(cv::Size(), mask.type(), seg_class_index);
277-
replace_roi.copyTo(submat);
279+
cv::Mat replace_roi(
280+
cv::Size(roi_object.width, roi_object.height), mask.type(),
281+
static_cast<uint8_t>(seg_class_index));
282+
replace_roi.copyTo(mask.colRange(roi_object.x_offset, roi_object.x_offset + roi_object.width)
283+
.rowRange(roi_object.y_offset, roi_object.y_offset + roi_object.height));
278284
}
279285

280286
} // namespace tensorrt_yolox

0 commit comments

Comments
 (0)