Skip to content

Commit 5105e83

Browse files
test(lidar_centerpoint): add test (#7029)
* test(lidar_centerpoint): add test Signed-off-by: kminoda <koji.minoda@tier4.jp> * update test Signed-off-by: kminoda <koji.minoda@tier4.jp> * update license Signed-off-by: kminoda <koji.minoda@tier4.jp> * style(pre-commit): autofix * fix namespace Signed-off-by: kminoda <koji.minoda@tier4.jp> * style(pre-commit): autofix --------- Signed-off-by: kminoda <koji.minoda@tier4.jp> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent d836237 commit 5105e83

File tree

4 files changed

+368
-0
lines changed

4 files changed

+368
-0
lines changed

perception/lidar_centerpoint/CMakeLists.txt

+14
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,20 @@ if(TRT_AVAIL AND CUDA_AVAIL AND CUDNN_AVAIL)
147147
TARGETS centerpoint_cuda_lib
148148
DESTINATION lib
149149
)
150+
151+
if(BUILD_TESTING)
152+
find_package(ament_cmake_gtest REQUIRED)
153+
ament_auto_add_gtest(test_detection_class_remapper
154+
test/test_detection_class_remapper.cpp
155+
)
156+
ament_auto_add_gtest(test_ros_utils
157+
test/test_ros_utils.cpp
158+
)
159+
ament_auto_add_gtest(test_nms
160+
test/test_nms.cpp
161+
)
162+
endif()
163+
150164
else()
151165
find_package(ament_cmake_auto REQUIRED)
152166
ament_auto_find_build_dependencies()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
// Copyright 2024 TIER IV, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <lidar_centerpoint/detection_class_remapper.hpp>
16+
17+
#include <gtest/gtest.h>
18+
19+
TEST(DetectionClassRemapperTest, MapClasses)
20+
{
21+
centerpoint::DetectionClassRemapper remapper;
22+
23+
// Set up the parameters for the remapper
24+
// Labels: CAR, TRUCK, TRAILER
25+
std::vector<int64_t> allow_remapping_by_area_matrix = {
26+
0, 0, 0, // CAR cannot be remapped
27+
0, 0, 1, // TRUCK can be remapped to TRAILER
28+
0, 1, 0 // TRAILER can be remapped to TRUCK
29+
};
30+
std::vector<double> min_area_matrix = {0.0, 0.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0};
31+
std::vector<double> max_area_matrix = {0.0, 0.0, 0.0, 0.0, 0.0, 999.0, 0.0, 10.0, 0.0};
32+
33+
remapper.setParameters(allow_remapping_by_area_matrix, min_area_matrix, max_area_matrix);
34+
35+
// Create a DetectedObjects message with some objects
36+
autoware_auto_perception_msgs::msg::DetectedObjects msg;
37+
38+
// CAR with area 4.0, which is out of the range for remapping
39+
autoware_auto_perception_msgs::msg::DetectedObject obj1;
40+
autoware_auto_perception_msgs::msg::ObjectClassification obj1_classification;
41+
obj1.shape.dimensions.x = 2.0;
42+
obj1.shape.dimensions.y = 2.0;
43+
obj1_classification.label = 0;
44+
obj1_classification.probability = 1.0;
45+
obj1.classification = {obj1_classification};
46+
msg.objects.push_back(obj1);
47+
48+
// TRUCK with area 16.0, which is in the range for remapping to TRAILER
49+
autoware_auto_perception_msgs::msg::DetectedObject obj2;
50+
autoware_auto_perception_msgs::msg::ObjectClassification obj2_classification;
51+
obj2.shape.dimensions.x = 8.0;
52+
obj2.shape.dimensions.y = 2.0;
53+
obj2_classification.label = 1;
54+
obj2_classification.probability = 1.0;
55+
obj2.classification = {obj2_classification};
56+
msg.objects.push_back(obj2);
57+
58+
// TRAILER with area 9.0, which is in the range for remapping to TRUCK
59+
autoware_auto_perception_msgs::msg::DetectedObject obj3;
60+
autoware_auto_perception_msgs::msg::ObjectClassification obj3_classification;
61+
obj3.shape.dimensions.x = 3.0;
62+
obj3.shape.dimensions.y = 3.0;
63+
obj3_classification.label = 2;
64+
obj3_classification.probability = 1.0;
65+
obj3.classification = {obj3_classification};
66+
msg.objects.push_back(obj3);
67+
68+
// TRAILER with area 12.0, which is out of the range for remapping
69+
autoware_auto_perception_msgs::msg::DetectedObject obj4;
70+
autoware_auto_perception_msgs::msg::ObjectClassification obj4_classification;
71+
obj4.shape.dimensions.x = 4.0;
72+
obj4.shape.dimensions.y = 3.0;
73+
obj4_classification.label = 2;
74+
obj4_classification.probability = 1.0;
75+
obj4.classification = {obj4_classification};
76+
msg.objects.push_back(obj4);
77+
78+
// Call the mapClasses function
79+
remapper.mapClasses(msg);
80+
81+
// Check the remapped labels
82+
EXPECT_EQ(msg.objects[0].classification[0].label, 0);
83+
EXPECT_EQ(msg.objects[1].classification[0].label, 2);
84+
EXPECT_EQ(msg.objects[2].classification[0].label, 1);
85+
EXPECT_EQ(msg.objects[3].classification[0].label, 2);
86+
}
87+
88+
int main(int argc, char ** argv)
89+
{
90+
testing::InitGoogleTest(&argc, argv);
91+
return RUN_ALL_TESTS();
92+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// Copyright 2024 TIER IV, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "lidar_centerpoint/postprocess/non_maximum_suppression.hpp"
16+
17+
#include <gtest/gtest.h>
18+
19+
TEST(NonMaximumSuppressionTest, Apply)
20+
{
21+
centerpoint::NonMaximumSuppression nms;
22+
centerpoint::NMSParams params;
23+
params.search_distance_2d_ = 1.0;
24+
params.iou_threshold_ = 0.2;
25+
params.nms_type_ = centerpoint::NMS_TYPE::IoU_BEV;
26+
params.target_class_names_ = {"CAR"};
27+
nms.setParameters(params);
28+
29+
std::vector<centerpoint::DetectedObject> input_objects(4);
30+
31+
// Object 1
32+
autoware_auto_perception_msgs::msg::ObjectClassification obj1_classification;
33+
obj1_classification.label = 0; // Assuming "car" has label 0
34+
obj1_classification.probability = 1.0;
35+
input_objects[0].classification = {obj1_classification}; // Assuming "car" has label 0
36+
input_objects[0].kinematics.pose_with_covariance.pose.position.x = 0.0;
37+
input_objects[0].kinematics.pose_with_covariance.pose.position.y = 0.0;
38+
input_objects[0].kinematics.pose_with_covariance.pose.orientation.x = 0.0;
39+
input_objects[0].kinematics.pose_with_covariance.pose.orientation.y = 0.0;
40+
input_objects[0].kinematics.pose_with_covariance.pose.orientation.z = 0.0;
41+
input_objects[0].kinematics.pose_with_covariance.pose.orientation.w = 1.0;
42+
input_objects[0].shape.type = autoware_auto_perception_msgs::msg::Shape::BOUNDING_BOX;
43+
input_objects[0].shape.dimensions.x = 4.0;
44+
input_objects[0].shape.dimensions.y = 2.0;
45+
46+
// Object 2 (overlaps with Object 1)
47+
autoware_auto_perception_msgs::msg::ObjectClassification obj2_classification;
48+
obj2_classification.label = 0; // Assuming "car" has label 0
49+
obj2_classification.probability = 1.0;
50+
input_objects[1].classification = {obj2_classification}; // Assuming "car" has label 0
51+
input_objects[1].kinematics.pose_with_covariance.pose.position.x = 0.5;
52+
input_objects[1].kinematics.pose_with_covariance.pose.position.y = 0.5;
53+
input_objects[1].kinematics.pose_with_covariance.pose.orientation.x = 0.0;
54+
input_objects[1].kinematics.pose_with_covariance.pose.orientation.y = 0.0;
55+
input_objects[1].kinematics.pose_with_covariance.pose.orientation.z = 0.0;
56+
input_objects[1].kinematics.pose_with_covariance.pose.orientation.w = 1.0;
57+
input_objects[1].shape.type = autoware_auto_perception_msgs::msg::Shape::BOUNDING_BOX;
58+
input_objects[1].shape.dimensions.x = 4.0;
59+
input_objects[1].shape.dimensions.y = 2.0;
60+
61+
// Object 3
62+
autoware_auto_perception_msgs::msg::ObjectClassification obj3_classification;
63+
obj3_classification.label = 0; // Assuming "car" has label 0
64+
obj3_classification.probability = 1.0;
65+
input_objects[2].classification = {obj3_classification}; // Assuming "car" has label 0
66+
input_objects[2].kinematics.pose_with_covariance.pose.position.x = 5.0;
67+
input_objects[2].kinematics.pose_with_covariance.pose.position.y = 5.0;
68+
input_objects[2].kinematics.pose_with_covariance.pose.orientation.x = 0.0;
69+
input_objects[2].kinematics.pose_with_covariance.pose.orientation.y = 0.0;
70+
input_objects[2].kinematics.pose_with_covariance.pose.orientation.z = 0.0;
71+
input_objects[2].kinematics.pose_with_covariance.pose.orientation.w = 1.0;
72+
input_objects[2].shape.type = autoware_auto_perception_msgs::msg::Shape::BOUNDING_BOX;
73+
input_objects[2].shape.dimensions.x = 4.0;
74+
input_objects[2].shape.dimensions.y = 2.0;
75+
76+
// Object 4 (different class)
77+
autoware_auto_perception_msgs::msg::ObjectClassification obj4_classification;
78+
obj4_classification.label = 1; // Assuming "pedestrian" has label 1
79+
obj4_classification.probability = 1.0;
80+
input_objects[3].classification = {obj4_classification}; // Assuming "pedestrian" has label 1
81+
input_objects[3].kinematics.pose_with_covariance.pose.position.x = 0.0;
82+
input_objects[3].kinematics.pose_with_covariance.pose.position.y = 0.0;
83+
input_objects[3].kinematics.pose_with_covariance.pose.orientation.x = 0.0;
84+
input_objects[3].kinematics.pose_with_covariance.pose.orientation.y = 0.0;
85+
input_objects[3].kinematics.pose_with_covariance.pose.orientation.z = 0.0;
86+
input_objects[3].kinematics.pose_with_covariance.pose.orientation.w = 1.0;
87+
input_objects[3].shape.type = autoware_auto_perception_msgs::msg::Shape::BOUNDING_BOX;
88+
input_objects[3].shape.dimensions.x = 0.5;
89+
input_objects[3].shape.dimensions.y = 0.5;
90+
91+
std::vector<centerpoint::DetectedObject> output_objects = nms.apply(input_objects);
92+
93+
// Assert the expected number of output objects
94+
EXPECT_EQ(output_objects.size(), 3);
95+
96+
// Assert that input_objects[2] is included in the output_objects
97+
bool is_input_object_2_included = false;
98+
for (const auto & output_object : output_objects) {
99+
if (output_object == input_objects[2]) {
100+
is_input_object_2_included = true;
101+
break;
102+
}
103+
}
104+
EXPECT_TRUE(is_input_object_2_included);
105+
106+
// Assert that input_objects[3] is included in the output_objects
107+
bool is_input_object_3_included = false;
108+
for (const auto & output_object : output_objects) {
109+
if (output_object == input_objects[3]) {
110+
is_input_object_3_included = true;
111+
break;
112+
}
113+
}
114+
EXPECT_TRUE(is_input_object_3_included);
115+
}
116+
117+
int main(int argc, char ** argv)
118+
{
119+
testing::InitGoogleTest(&argc, argv);
120+
return RUN_ALL_TESTS();
121+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
// Copyright 2024 TIER IV, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "lidar_centerpoint/ros_utils.hpp"
16+
17+
#include <gtest/gtest.h>
18+
19+
TEST(TestSuite, box3DToDetectedObject)
20+
{
21+
std::vector<std::string> class_names = {"CAR", "TRUCK", "BUS", "TRAILER",
22+
"BICYCLE", "MOTORBIKE", "PEDESTRIAN"};
23+
24+
// Test case 1: Test with valid label, has_twist=true, has_variance=true
25+
{
26+
centerpoint::Box3D box3d;
27+
box3d.score = 0.8f;
28+
box3d.label = 0; // CAR
29+
box3d.x = 1.0;
30+
box3d.y = 2.0;
31+
box3d.z = 3.0;
32+
box3d.yaw = 0.5;
33+
box3d.length = 4.0;
34+
box3d.width = 2.0;
35+
box3d.height = 1.5;
36+
box3d.vel_x = 1.0;
37+
box3d.vel_y = 0.5;
38+
box3d.x_variance = 0.1;
39+
box3d.y_variance = 0.2;
40+
box3d.z_variance = 0.3;
41+
box3d.yaw_variance = 0.4;
42+
box3d.vel_x_variance = 0.5;
43+
box3d.vel_y_variance = 0.6;
44+
45+
autoware_auto_perception_msgs::msg::DetectedObject obj;
46+
centerpoint::box3DToDetectedObject(box3d, class_names, true, true, obj);
47+
48+
EXPECT_FLOAT_EQ(obj.existence_probability, 0.8f);
49+
EXPECT_EQ(
50+
obj.classification[0].label, autoware_auto_perception_msgs::msg::ObjectClassification::CAR);
51+
EXPECT_FLOAT_EQ(obj.kinematics.pose_with_covariance.pose.position.x, 1.0);
52+
EXPECT_FLOAT_EQ(obj.kinematics.pose_with_covariance.pose.position.y, 2.0);
53+
EXPECT_FLOAT_EQ(obj.kinematics.pose_with_covariance.pose.position.z, 3.0);
54+
EXPECT_FLOAT_EQ(obj.shape.dimensions.x, 4.0);
55+
EXPECT_FLOAT_EQ(obj.shape.dimensions.y, 2.0);
56+
EXPECT_FLOAT_EQ(obj.shape.dimensions.z, 1.5);
57+
EXPECT_TRUE(obj.kinematics.has_position_covariance);
58+
EXPECT_TRUE(obj.kinematics.has_twist);
59+
EXPECT_TRUE(obj.kinematics.has_twist_covariance);
60+
}
61+
62+
// Test case 2: Test with invalid label, has_twist=false, has_variance=false
63+
{
64+
centerpoint::Box3D box3d;
65+
box3d.score = 0.5f;
66+
box3d.label = 10; // Invalid
67+
68+
autoware_auto_perception_msgs::msg::DetectedObject obj;
69+
centerpoint::box3DToDetectedObject(box3d, class_names, false, false, obj);
70+
71+
EXPECT_FLOAT_EQ(obj.existence_probability, 0.5f);
72+
EXPECT_EQ(
73+
obj.classification[0].label,
74+
autoware_auto_perception_msgs::msg::ObjectClassification::UNKNOWN);
75+
EXPECT_FALSE(obj.kinematics.has_position_covariance);
76+
EXPECT_FALSE(obj.kinematics.has_twist);
77+
EXPECT_FALSE(obj.kinematics.has_twist_covariance);
78+
}
79+
}
80+
81+
TEST(TestSuite, getSemanticType)
82+
{
83+
EXPECT_EQ(
84+
centerpoint::getSemanticType("CAR"),
85+
autoware_auto_perception_msgs::msg::ObjectClassification::CAR);
86+
EXPECT_EQ(
87+
centerpoint::getSemanticType("TRUCK"),
88+
autoware_auto_perception_msgs::msg::ObjectClassification::TRUCK);
89+
EXPECT_EQ(
90+
centerpoint::getSemanticType("BUS"),
91+
autoware_auto_perception_msgs::msg::ObjectClassification::BUS);
92+
EXPECT_EQ(
93+
centerpoint::getSemanticType("TRAILER"),
94+
autoware_auto_perception_msgs::msg::ObjectClassification::TRAILER);
95+
EXPECT_EQ(
96+
centerpoint::getSemanticType("BICYCLE"),
97+
autoware_auto_perception_msgs::msg::ObjectClassification::BICYCLE);
98+
EXPECT_EQ(
99+
centerpoint::getSemanticType("MOTORBIKE"),
100+
autoware_auto_perception_msgs::msg::ObjectClassification::MOTORCYCLE);
101+
EXPECT_EQ(
102+
centerpoint::getSemanticType("PEDESTRIAN"),
103+
autoware_auto_perception_msgs::msg::ObjectClassification::PEDESTRIAN);
104+
EXPECT_EQ(
105+
centerpoint::getSemanticType("UNKNOWN"),
106+
autoware_auto_perception_msgs::msg::ObjectClassification::UNKNOWN);
107+
}
108+
109+
TEST(TestSuite, convertPoseCovarianceMatrix)
110+
{
111+
centerpoint::Box3D box3d;
112+
box3d.x_variance = 0.1;
113+
box3d.y_variance = 0.2;
114+
box3d.z_variance = 0.3;
115+
box3d.yaw_variance = 0.4;
116+
117+
std::array<double, 36> pose_covariance = centerpoint::convertPoseCovarianceMatrix(box3d);
118+
119+
EXPECT_FLOAT_EQ(pose_covariance[0], 0.1);
120+
EXPECT_FLOAT_EQ(pose_covariance[7], 0.2);
121+
EXPECT_FLOAT_EQ(pose_covariance[14], 0.3);
122+
EXPECT_FLOAT_EQ(pose_covariance[35], 0.4);
123+
}
124+
125+
TEST(TestSuite, convertTwistCovarianceMatrix)
126+
{
127+
centerpoint::Box3D box3d;
128+
box3d.vel_x_variance = 0.1;
129+
box3d.vel_y_variance = 0.2;
130+
131+
std::array<double, 36> twist_covariance = centerpoint::convertTwistCovarianceMatrix(box3d);
132+
133+
EXPECT_FLOAT_EQ(twist_covariance[0], 0.1);
134+
EXPECT_FLOAT_EQ(twist_covariance[7], 0.2);
135+
}
136+
137+
int main(int argc, char ** argv)
138+
{
139+
testing::InitGoogleTest(&argc, argv);
140+
return RUN_ALL_TESTS();
141+
}

0 commit comments

Comments
 (0)