diff --git a/common/tensorrt_common/CMakeLists.txt b/common/tensorrt_common/CMakeLists.txt index cb24448f1a993..80b101e2afef4 100644 --- a/common/tensorrt_common/CMakeLists.txt +++ b/common/tensorrt_common/CMakeLists.txt @@ -16,6 +16,45 @@ if(NOT (CUDAToolkit_FOUND AND CUDNN_FOUND AND TENSORRT_FOUND)) return() endif() +# Download onnx +set(DATA_PATH "${CMAKE_CURRENT_SOURCE_DIR}/data") +if(NOT EXISTS "${DATA_PATH}") + execute_process(COMMAND mkdir -p ${DATA_PATH}) +endif() +function(download FILE_NAME FILE_HASH) + message(STATUS "Checking and downloading ${FILE_NAME}") + set(FILE_PATH ${DATA_PATH}/${FILE_NAME}) + set(STATUS_CODE 0) + message(STATUS "start ${FILE_NAME}") + if(EXISTS ${FILE_PATH}) + message(STATUS "found ${FILE_NAME}") + file(MD5 ${FILE_PATH} EXISTING_FILE_HASH) + if(${FILE_HASH} STREQUAL ${EXISTING_FILE_HASH}) + message(STATUS "same ${FILE_NAME}") + message(STATUS "File already exists.") + else() + message(STATUS "diff ${FILE_NAME}") + message(STATUS "File hash changes. Downloading now ...") + file(DOWNLOAD https://awf.ml.dev.web.auto/perception/models/${FILE_NAME} ${FILE_PATH} STATUS DOWNLOAD_STATUS TIMEOUT 3600) + list(GET DOWNLOAD_STATUS 0 STATUS_CODE) + list(GET DOWNLOAD_STATUS 1 ERROR_MESSAGE) + endif() + else() + message(STATUS "not found ${FILE_NAME}") + message(STATUS "File doesn't exists. Downloading now ...") + file(DOWNLOAD https://awf.ml.dev.web.auto/perception/models/${FILE_NAME} ${FILE_PATH} STATUS DOWNLOAD_STATUS TIMEOUT 3600) + list(GET DOWNLOAD_STATUS 0 STATUS_CODE) + list(GET DOWNLOAD_STATUS 1 ERROR_MESSAGE) + endif() + if(${STATUS_CODE} EQUAL 0) + message(STATUS "Download completed successfully!") + else() + message(FATAL_ERROR "Error occurred during download: ${ERROR_MESSAGE}") + endif() +endfunction() + +download(yolov5s.onnx 646b295db6089597e79163446d6eedca) + add_library(${PROJECT_NAME} SHARED src/tensorrt_common.cpp src/simple_profiler.cpp @@ -61,6 +100,11 @@ if(BUILD_TESTING) set(ament_cmake_uncrustify_FOUND TRUE) ament_lint_auto_find_test_dependencies() + + file(GLOB_RECURSE test_files test/*.cpp) + ament_add_ros_isolated_gtest(test_tensorrt_common ${test_files}) + target_link_libraries(test_tensorrt_common ${PROJECT_NAME}) + endif() install(TARGETS ${PROJECT_NAME} EXPORT export_${PROJECT_NAME}) diff --git a/common/tensorrt_common/package.xml b/common/tensorrt_common/package.xml index 7d3995f93f7fe..196af86f4541a 100644 --- a/common/tensorrt_common/package.xml +++ b/common/tensorrt_common/package.xml @@ -18,6 +18,7 @@ rclcpp + ament_cmake_ros ament_lint_auto ament_lint_common diff --git a/common/tensorrt_common/test/test_tensorrt_common.cpp b/common/tensorrt_common/test/test_tensorrt_common.cpp new file mode 100644 index 0000000000000..3b60ed84c8c08 --- /dev/null +++ b/common/tensorrt_common/test/test_tensorrt_common.cpp @@ -0,0 +1,43 @@ +// Copyright 2023 TIER IV, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorrt_common/tensorrt_common.hpp" + +#include + +// test get_input_dims function +TEST(TrtCommonTest, TestGetInputDims) +{ + std::string onnx_file_path = + "src/universe/autoware.universe/common/tensorrt_common/data/yolov5s.onnx"; + nvinfer1::Dims input_dims = tensorrt_common::get_input_dims(onnx_file_path); + ASSERT_GT(input_dims.nbDims, 0); +} + +// test is_valid_precision_string function +TEST(TrtCommonTest, TestIsValidPrecisionString) +{ + std::string valid_precision = "fp16"; + std::string invalid_precision = "invalid_precision"; + ASSERT_TRUE(tensorrt_common::is_valid_precision_string(valid_precision)); + ASSERT_FALSE(tensorrt_common::is_valid_precision_string(invalid_precision)); +} + +// In the future, more test cases will be written to test the functionality of TrtCommon class + +int main(int argc, char * argv[]) +{ + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}