@@ -187,6 +187,7 @@ void TrtYoloXNode::onImage(const sensor_msgs::msg::Image::ConstSharedPtr msg)
187
187
in_image_ptr->image , cv::Point (left, top), cv::Point (right, bottom), cv::Scalar (0 , 0 , 255 ), 3 ,
188
188
8 , 0 );
189
189
// Refine mask: replacing segmentation mask by roi class
190
+ // This should remove when the segmentation accuracy is high
190
191
if (is_roi_overlap_segment_) {
191
192
overlapSegmentByRoi (yolox_object, mask);
192
193
}
@@ -202,10 +203,9 @@ void TrtYoloXNode::onImage(const sensor_msgs::msg::Image::ConstSharedPtr msg)
202
203
objects_pub_->publish (out_objects);
203
204
204
205
if (is_publish_color_mask_) {
205
- auto & color_mask = color_masks.at (0 );
206
- cv::resize (
207
- color_mask, color_mask, cv::Size (in_image_ptr->image .cols , in_image_ptr->image .rows ), 0 , 0 ,
208
- cv::INTER_NEAREST);
206
+ cv::Mat color_mask =
207
+ cv::Mat::zeros (in_image_ptr->image .rows , in_image_ptr->image .cols , CV_8UC3);
208
+ trt_yolox_->getColorizedMask (trt_yolox_->getColorMap (), mask, color_mask);
209
209
sensor_msgs::msg::Image::SharedPtr output_color_mask_msg =
210
210
cv_bridge::CvImage (std_msgs::msg::Header (), sensor_msgs::image_encodings::BGR8, color_mask)
211
211
.toImageMsg ();
@@ -250,31 +250,37 @@ void TrtYoloXNode::replaceLabelMap()
250
250
251
251
int TrtYoloXNode::mapRoiLabel2SegLabel (const int32_t roi_label_index)
252
252
{
253
- auto & roi_label = label_map_[roi_label_index];
254
- if (roi_label == " CAR" || roi_label == " BUS" || roi_label == " TRUCK" ) {
255
- return static_cast <int >(roi_label_index + 11 );
256
- }
257
- if (roi_label == " PEDESTRIAN" ) {
258
- return 11 ; // person index in segment_color_map
259
- }
260
- if (roi_label == " MOTORCYCLE" ) {
261
- return 17 ; // motocycle index in segment_color_map
262
- }
263
- if (roi_label == " BICYCLE" ) {
264
- return 18 ; // bicycle index in segment_color_map
253
+ switch (roi_label_index) {
254
+ case 0 :
255
+ return 5 ;
256
+ case 1 :
257
+ return 13 ;
258
+ case 2 :
259
+ return 14 ;
260
+ case 3 :
261
+ return 15 ;
262
+ case 4 :
263
+ return 18 ;
264
+ case 5 :
265
+ return 17 ;
266
+ case 6 :
267
+ return 11 ;
268
+ default :
269
+ return -1 ;
265
270
}
266
271
return -1 ;
267
272
}
268
273
269
274
void TrtYoloXNode::overlapSegmentByRoi (const tensorrt_yolox::Object & roi_object, cv::Mat & mask)
270
275
{
271
276
if (roi_object.score < overlap_roi_score_threshold_) return ;
272
- cv::Mat submat = mask.colRange (roi_object.x_offset , roi_object.width )
273
- .rowRange (roi_object.y_offset , roi_object.height );
274
277
int seg_class_index = mapRoiLabel2SegLabel (roi_object.type );
275
278
if (seg_class_index < 0 ) return ;
276
- cv::Mat replace_roi (cv::Size (), mask.type (), seg_class_index);
277
- replace_roi.copyTo (submat);
279
+ cv::Mat replace_roi (
280
+ cv::Size (roi_object.width , roi_object.height ), mask.type (),
281
+ static_cast <uint8_t >(seg_class_index));
282
+ replace_roi.copyTo (mask.colRange (roi_object.x_offset , roi_object.x_offset + roi_object.width )
283
+ .rowRange (roi_object.y_offset , roi_object.y_offset + roi_object.height ));
278
284
}
279
285
280
286
} // namespace tensorrt_yolox
0 commit comments