-
Notifications
You must be signed in to change notification settings - Fork 706
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(autoware_tensorrt_common): validate TensorRT engine version for cached engine #10320
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -46,6 +46,10 @@ using ProfileDimsPtr = std::unique_ptr<std::vector<ProfileDims>>; | |||||
using TensorsVec = std::vector<std::pair<void *, nvinfer1::Dims>>; | ||||||
using TensorsMap = std::unordered_map<const char *, std::pair<void *, nvinfer1::Dims>>; | ||||||
|
||||||
// Byte offsets into a serialized TensorRT engine (plan) file at which
// TrtCommon::validateEngine() reads the major, minor and patch version of the
// TensorRT release that built the engine.
// NOTE(review): these offsets are not derived from any API visible in this file;
// they reportedly come from NVIDIA guidance. Add a link to that source so the
// values can be re-checked if the plan format changes.
constexpr int TRT_MAJOR_IDX = 24; | ||||||
constexpr int TRT_MINOR_IDX = 25; | ||||||
constexpr int TRT_PATCH_IDX = 26; | ||||||
|
||||||
/** | ||||||
* @class TrtCommon | ||||||
* @brief TensorRT common library. | ||||||
|
@@ -317,6 +321,13 @@ class TrtCommon // NOLINT | |||||
*/ | ||||||
bool buildEngineFromOnnx(); | ||||||
|
||||||
/** | ||||||
* @brief Validate that the cached engine file at the configured engine path was built with the TensorRT version in use. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
* | ||||||
* @return true when the TensorRT version recorded in the engine file equals the version this library is compiled against (always true when compiled against TensorRT older than 8.6, where the check is skipped); false otherwise. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
*/ | ||||||
bool validateEngine(); | ||||||
|
||||||
/** | ||||||
* @brief Load TensorRT engine. | ||||||
* | ||||||
|
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -115,7 +115,15 @@ | |||||||
// Load engine file if it exists | ||||||||
if (fs::exists(trt_config_->engine_path)) { | ||||||||
logger_->log(nvinfer1::ILogger::Severity::kINFO, "Loading engine"); | ||||||||
if (!loadEngine()) { | ||||||||
if (!validateEngine()) { | ||||||||
logger_->log( | ||||||||
nvinfer1::ILogger::Severity::kWARNING, | ||||||||
"Engine validation failed for loaded engine from file. Rebuilding engine"); | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
// Rebuild engine if version mismatch occurred | ||||||||
if (!build_engine_with_log()) { | ||||||||
return false; | ||||||||
} | ||||||||
} else if (!loadEngine()) { | ||||||||
return false; | ||||||||
} | ||||||||
logger_->log(nvinfer1::ILogger::Severity::kINFO, "Network validation"); | ||||||||
|
@@ -543,6 +551,33 @@ | |||||||
return true; | ||||||||
} | ||||||||
|
||||||||
/**
 * @brief Check whether the cached engine file was built by the TensorRT version in use.
 *
 * Reads only the leading bytes of the serialized plan at `trt_config_->engine_path` and
 * compares the bytes stored at TRT_MAJOR_IDX / TRT_MINOR_IDX / TRT_PATCH_IDX with
 * NV_TENSORRT_MAJOR / NV_TENSORRT_MINOR / NV_TENSORRT_PATCH. When compiled against a
 * TensorRT release older than 8.6 the check is compiled out and the engine is accepted.
 *
 * @return false if the file cannot be read up to the version bytes or if any version
 *         component differs; true otherwise.
 */
bool TrtCommon::validateEngine()
{
#if (NV_TENSORRT_MAJOR * 1000) + (NV_TENSORRT_MINOR * 100) + NV_TENSORRT_PATCH >= 8600
  // Engine files are binary and can be very large; only the header bytes up to the patch
  // version are needed, so avoid loading the whole file.
  std::ifstream engine_file(trt_config_->engine_path, std::ios::binary);
  std::string header(static_cast<std::size_t>(TRT_PATCH_IDX) + 1, '\0');
  const auto header_size = static_cast<std::streamsize>(header.size());
  engine_file.read(header.data(), header_size);
  if (engine_file.gcount() != header_size) {
    // Missing, unreadable or truncated file: indexing the version bytes would read out of
    // bounds, so report the engine as invalid and let the caller rebuild it.
    logger_->log(
      nvinfer1::ILogger::Severity::kWARNING,
      "Engine file is unreadable or too short to contain a TensorRT version");
    return false;
  }

  // Each version component is stored as a single unsigned byte.
  const auto version_at = [&header](const int idx) {
    return static_cast<int32_t>(static_cast<uint8_t>(header[static_cast<std::size_t>(idx)]));
  };
  const int32_t plan_major = version_at(TRT_MAJOR_IDX);
  const int32_t plan_minor = version_at(TRT_MINOR_IDX);
  const int32_t plan_patch = version_at(TRT_PATCH_IDX);

  logger_->log(
    nvinfer1::ILogger::Severity::kINFO, "Plan was created with TensorRT %d.%d.%d", plan_major,
    plan_minor, plan_patch);

  // Compare component-wise: folding into major*1000 + minor*100 + patch can make distinct
  // versions collide once minor >= 10 or patch >= 100.
  const bool same_version = plan_major == NV_TENSORRT_MAJOR && plan_minor == NV_TENSORRT_MINOR &&
                            plan_patch == NV_TENSORRT_PATCH;
  if (!same_version) {
    logger_->log(
      nvinfer1::ILogger::Severity::kWARNING,
      "Plan was created with a different version of TensorRT! Current version: %d.%d.%d",
      NV_TENSORRT_MAJOR, NV_TENSORRT_MINOR, NV_TENSORRT_PATCH);
    return false;
  }
#endif
  return true;
}
|
||||||||
bool TrtCommon::loadEngine() | ||||||||
{ | ||||||||
std::ifstream engine_file(trt_config_->engine_path); | ||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand that these magic values come from an official source (an answer from NVIDIA). Could you leave a reference here so that people know where they came from (and can hopefully track and fix them if/when NVIDIA changes them in the future)?