Skip to content

Commit a15ac04

Browse files
committed
fix: skip mask size as yolox output
Signed-off-by: badai-nguyen <dai.nguyen@tier4.jp>
1 parent 8832a3a commit a15ac04

File tree

2 files changed

+18
-13
lines changed

2 files changed

+18
-13
lines changed

perception/tensorrt_yolox/include/tensorrt_yolox/tensorrt_yolox_node.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ class TrtYoloXNode : public rclcpp::Node
7676
void onImage(const sensor_msgs::msg::Image::ConstSharedPtr msg);
7777
bool readLabelFile(const std::string & label_path);
7878
void replaceLabelMap();
79-
void overlapSegmentByRoi(const tensorrt_yolox::Object & object, cv::Mat & mask);
79+
void overlapSegmentByRoi(
80+
const tensorrt_yolox::Object & object, cv::Mat & mask, const int width, const int height);
8081
int mapRoiLabel2SegLabel(const int32_t roi_label_index);
8182
image_transport::Publisher image_pub_;
8283
image_transport::Publisher mask_pub_;

perception/tensorrt_yolox/src/tensorrt_yolox_node.cpp

+16-12
Original file line numberDiff line numberDiff line change
@@ -190,10 +190,6 @@ void TrtYoloXNode::onImage(const sensor_msgs::msg::Image::ConstSharedPtr msg)
190190
return;
191191
}
192192
auto & mask = masks.at(0);
193-
// TODO(badai-nguyen): change to postprocess on gpu option
194-
cv::resize(
195-
mask, mask, cv::Size(in_image_ptr->image.cols, in_image_ptr->image.rows), 0, 0,
196-
cv::INTER_NEAREST);
197193

198194
for (const auto & yolox_object : objects.at(0)) {
199195
tier4_perception_msgs::msg::DetectedObjectWithFeature object;
@@ -217,7 +213,7 @@ void TrtYoloXNode::onImage(const sensor_msgs::msg::Image::ConstSharedPtr msg)
217213
// Refine mask: replacing segmentation mask by roi class
218214
// This should remove when the segmentation accuracy is high
219215
if (is_roi_overlap_segment_ && trt_yolox_->getMultitaskNum() > 0) {
220-
overlapSegmentByRoi(yolox_object, mask);
216+
overlapSegmentByRoi(yolox_object, mask, width, height);
221217
}
222218
}
223219
// TODO(badai-nguyen): consider to change to 4bits data transfer
@@ -249,8 +245,7 @@ void TrtYoloXNode::onImage(const sensor_msgs::msg::Image::ConstSharedPtr msg)
249245
}
250246

251247
if (is_publish_color_mask_ && trt_yolox_->getMultitaskNum() > 0) {
252-
cv::Mat color_mask =
253-
cv::Mat::zeros(in_image_ptr->image.rows, in_image_ptr->image.cols, CV_8UC3);
248+
cv::Mat color_mask = cv::Mat::zeros(mask.rows, mask.cols, CV_8UC3);
254249
trt_yolox_->getColorizedMask(trt_yolox_->getColorMap(), mask, color_mask);
255250
sensor_msgs::msg::Image::SharedPtr output_color_mask_msg =
256251
cv_bridge::CvImage(std_msgs::msg::Header(), sensor_msgs::image_encodings::BGR8, color_mask)
@@ -304,16 +299,25 @@ int TrtYoloXNode::mapRoiLabel2SegLabel(const int32_t roi_label_index)
304299
return -1;
305300
}
306301

307-
void TrtYoloXNode::overlapSegmentByRoi(const tensorrt_yolox::Object & roi_object, cv::Mat & mask)
302+
void TrtYoloXNode::overlapSegmentByRoi(
303+
const tensorrt_yolox::Object & roi_object, cv::Mat & mask, const int orig_width,
304+
const int orig_height)
308305
{
309306
if (roi_object.score < overlap_roi_score_threshold_) return;
310307
int seg_class_index = mapRoiLabel2SegLabel(roi_object.type);
311308
if (seg_class_index < 0) return;
309+
310+
const float scale_x = static_cast<float>(mask.cols) / static_cast<float>(orig_width);
311+
const float scale_y = static_cast<float>(mask.rows) / static_cast<float>(orig_height);
312+
const int roi_width = static_cast<int>(roi_object.width * scale_x);
313+
const int roi_height = static_cast<int>(roi_object.height * scale_y);
314+
const int roi_x_offset = static_cast<int>(roi_object.x_offset * scale_x);
315+
const int roi_y_offset = static_cast<int>(roi_object.y_offset * scale_y);
316+
312317
cv::Mat replace_roi(
313-
cv::Size(roi_object.width, roi_object.height), mask.type(),
314-
static_cast<uint8_t>(seg_class_index));
315-
replace_roi.copyTo(mask.colRange(roi_object.x_offset, roi_object.x_offset + roi_object.width)
316-
.rowRange(roi_object.y_offset, roi_object.y_offset + roi_object.height));
318+
cv::Size(roi_width, roi_height), mask.type(), static_cast<uint8_t>(seg_class_index));
319+
replace_roi.copyTo(mask.colRange(roi_x_offset, roi_x_offset + roi_width)
320+
.rowRange(roi_y_offset, roi_y_offset + roi_height));
317321
}
318322

319323
} // namespace tensorrt_yolox

0 commit comments

Comments
 (0)