From 880519cab5391e79d53d5a7fff6aebed0a716c8d Mon Sep 17 00:00:00 2001 From: Cynthia Liu Date: Thu, 28 Sep 2023 18:01:05 +0800 Subject: [PATCH 1/2] feat(tensorrt_common): add gtest Signed-off-by: Cynthia Liu --- common/tensorrt_common/CMakeLists.txt | 44 +++++++++++++++++++ common/tensorrt_common/package.xml | 1 + .../test/test_tensorrt_common.cpp | 38 ++++++++++++++++ 3 files changed, 83 insertions(+) create mode 100644 common/tensorrt_common/test/test_tensorrt_common.cpp diff --git a/common/tensorrt_common/CMakeLists.txt b/common/tensorrt_common/CMakeLists.txt index cb24448f1a993..80b101e2afef4 100644 --- a/common/tensorrt_common/CMakeLists.txt +++ b/common/tensorrt_common/CMakeLists.txt @@ -16,6 +16,45 @@ if(NOT (CUDAToolkit_FOUND AND CUDNN_FOUND AND TENSORRT_FOUND)) return() endif() +# Download onnx +set(DATA_PATH "${CMAKE_CURRENT_SOURCE_DIR}/data") +if(NOT EXISTS "${DATA_PATH}") + execute_process(COMMAND mkdir -p ${DATA_PATH}) +endif() +function(download FILE_NAME FILE_HASH) + message(STATUS "Checking and downloading ${FILE_NAME}") + set(FILE_PATH ${DATA_PATH}/${FILE_NAME}) + set(STATUS_CODE 0) + message(STATUS "start ${FILE_NAME}") + if(EXISTS ${FILE_PATH}) + message(STATUS "found ${FILE_NAME}") + file(MD5 ${FILE_PATH} EXISTING_FILE_HASH) + if(${FILE_HASH} STREQUAL ${EXISTING_FILE_HASH}) + message(STATUS "same ${FILE_NAME}") + message(STATUS "File already exists.") + else() + message(STATUS "diff ${FILE_NAME}") + message(STATUS "File hash changes. Downloading now ...") + file(DOWNLOAD https://awf.ml.dev.web.auto/perception/models/${FILE_NAME} ${FILE_PATH} STATUS DOWNLOAD_STATUS TIMEOUT 3600) + list(GET DOWNLOAD_STATUS 0 STATUS_CODE) + list(GET DOWNLOAD_STATUS 1 ERROR_MESSAGE) + endif() + else() + message(STATUS "not found ${FILE_NAME}") + message(STATUS "File doesn't exists. Downloading now ...") + file(DOWNLOAD https://awf.ml.dev.web.auto/perception/models/${FILE_NAME} ${FILE_PATH} STATUS DOWNLOAD_STATUS TIMEOUT 3600) + list(GET DOWNLOAD_STATUS 0 STATUS_CODE) + list(GET DOWNLOAD_STATUS 1 ERROR_MESSAGE) + endif() + if(${STATUS_CODE} EQUAL 0) + message(STATUS "Download completed successfully!") + else() + message(FATAL_ERROR "Error occurred during download: ${ERROR_MESSAGE}") + endif() +endfunction() + +download(yolov5s.onnx 646b295db6089597e79163446d6eedca) + add_library(${PROJECT_NAME} SHARED src/tensorrt_common.cpp src/simple_profiler.cpp @@ -61,6 +100,11 @@ if(BUILD_TESTING) set(ament_cmake_uncrustify_FOUND TRUE) ament_lint_auto_find_test_dependencies() + + file(GLOB_RECURSE test_files test/*.cpp) + ament_add_ros_isolated_gtest(test_tensorrt_common ${test_files}) + target_link_libraries(test_tensorrt_common ${PROJECT_NAME}) + endif() install(TARGETS ${PROJECT_NAME} EXPORT export_${PROJECT_NAME}) diff --git a/common/tensorrt_common/package.xml b/common/tensorrt_common/package.xml index 7d3995f93f7fe..196af86f4541a 100644 --- a/common/tensorrt_common/package.xml +++ b/common/tensorrt_common/package.xml @@ -18,6 +18,7 @@ rclcpp + ament_cmake_ros ament_lint_auto ament_lint_common diff --git a/common/tensorrt_common/test/test_tensorrt_common.cpp b/common/tensorrt_common/test/test_tensorrt_common.cpp new file mode 100644 index 0000000000000..0f783a22562f7 --- /dev/null +++ b/common/tensorrt_common/test/test_tensorrt_common.cpp @@ -0,0 +1,38 @@ +// Copyright 2023 TIER IV, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "tensorrt_common/tensorrt_common.hpp" + +// test get_input_dims function +TEST(TrtCommonTest, TestGetInputDims) { + std::string onnx_file_path = "src/universe/autoware.universe/common/tensorrt_common/data/yolov5s.onnx"; + nvinfer1::Dims input_dims = tensorrt_common::get_input_dims(onnx_file_path); + ASSERT_GT(input_dims.nbDims, 0); +} + +// test is_valid_precision_string function +TEST(TrtCommonTest, TestIsValidPrecisionString) { + std::string valid_precision = "fp16"; + std::string invalid_precision = "invalid_precision"; + ASSERT_TRUE(tensorrt_common::is_valid_precision_string(valid_precision)); + ASSERT_FALSE(tensorrt_common::is_valid_precision_string(invalid_precision)); +} + +// In the future, more test cases will be written to test the functionality of TrtCommon class + +int main(int argc, char * argv[]) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} From 8ade1a891f86ee04e661e02d58a7d476a092bd6d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 Sep 2023 10:07:34 +0000 Subject: [PATCH 2/2] style(pre-commit): autofix --- .../test/test_tensorrt_common.cpp | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/common/tensorrt_common/test/test_tensorrt_common.cpp b/common/tensorrt_common/test/test_tensorrt_common.cpp index 0f783a22562f7..3b60ed84c8c08 100644 --- a/common/tensorrt_common/test/test_tensorrt_common.cpp +++ b/common/tensorrt_common/test/test_tensorrt_common.cpp @@ -12,18 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include "tensorrt_common/tensorrt_common.hpp" +#include + // test get_input_dims function -TEST(TrtCommonTest, TestGetInputDims) { - std::string onnx_file_path = "src/universe/autoware.universe/common/tensorrt_common/data/yolov5s.onnx"; +TEST(TrtCommonTest, TestGetInputDims) +{ + std::string onnx_file_path = + "src/universe/autoware.universe/common/tensorrt_common/data/yolov5s.onnx"; nvinfer1::Dims input_dims = tensorrt_common::get_input_dims(onnx_file_path); ASSERT_GT(input_dims.nbDims, 0); } // test is_valid_precision_string function -TEST(TrtCommonTest, TestIsValidPrecisionString) { +TEST(TrtCommonTest, TestIsValidPrecisionString) +{ std::string valid_precision = "fp16"; std::string invalid_precision = "invalid_precision"; ASSERT_TRUE(tensorrt_common::is_valid_precision_string(valid_precision)); @@ -32,7 +36,8 @@ TEST(TrtCommonTest, TestIsValidPrecisionString) { // In the future, more test cases will be written to test the functionality of TrtCommon class -int main(int argc, char * argv[]) { - testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); +int main(int argc, char * argv[]) +{ + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); }