From f67a830deee43670eba159ff6a978d1a88922222 Mon Sep 17 00:00:00 2001
From: Amadeusz Szymko <amadeusz.szymko.2@tier4.jp>
Date: Fri, 21 Mar 2025 16:05:13 +0900
Subject: [PATCH] fix(autoware_tensorrt_common): validate TensorRT engine version
 for cached engine

Signed-off-by: Amadeusz Szymko <amadeusz.szymko.2@tier4.jp>
---
 .../tensorrt_common/tensorrt_common.hpp       | 11 ++++++
 .../src/tensorrt_common.cpp                   | 37 ++++++++++++++++++-
 2 files changed, 47 insertions(+), 1 deletion(-)

diff --git a/perception/autoware_tensorrt_common/include/autoware/tensorrt_common/tensorrt_common.hpp b/perception/autoware_tensorrt_common/include/autoware/tensorrt_common/tensorrt_common.hpp
index 9355732962e23..d35efbfd0ab4a 100644
--- a/perception/autoware_tensorrt_common/include/autoware/tensorrt_common/tensorrt_common.hpp
+++ b/perception/autoware_tensorrt_common/include/autoware/tensorrt_common/tensorrt_common.hpp
@@ -46,6 +46,10 @@ using ProfileDimsPtr = std::unique_ptr<std::vector<ProfileDims>>;
 using TensorsVec = std::vector<std::pair<void *, nvinfer1::Dims>>;
 using TensorsMap = std::unordered_map<const char *, std::pair<void *, nvinfer1::Dims>>;
 
+constexpr int TRT_MAJOR_IDX = 24;  // byte offset of the TensorRT major version in an engine file
+constexpr int TRT_MINOR_IDX = 25;  // byte offset of the TensorRT minor version in an engine file
+constexpr int TRT_PATCH_IDX = 26;  // byte offset of the TensorRT patch version in an engine file
+
 /**
  * @class TrtCommon
  * @brief TensorRT common library.
@@ -317,6 +321,13 @@ class TrtCommon  // NOLINT
    */
   bool buildEngineFromOnnx();
 
+  /**
+   * @brief Validate the cached TensorRT engine file against the TensorRT version in use.
+   *
+   * @return False if the cached engine is deemed incompatible; always true before TensorRT 8.6.
+   */
+  bool validateEngine();
+
   /**
    * @brief Load TensorRT engine.
    *
diff --git a/perception/autoware_tensorrt_common/src/tensorrt_common.cpp b/perception/autoware_tensorrt_common/src/tensorrt_common.cpp
index ba422277416ab..1cd9c58667992 100644
--- a/perception/autoware_tensorrt_common/src/tensorrt_common.cpp
+++ b/perception/autoware_tensorrt_common/src/tensorrt_common.cpp
@@ -115,7 +115,15 @@ bool TrtCommon::setup(ProfileDimsPtr profile_dims, NetworkIOPtr network_io)
   // Load engine file if it exists
   if (fs::exists(trt_config_->engine_path)) {
     logger_->log(nvinfer1::ILogger::Severity::kINFO, "Loading engine");
-    if (!loadEngine()) {
+    if (!validateEngine()) {
+      logger_->log(
+        nvinfer1::ILogger::Severity::kWARNING,
+        "Engine validation failed for loaded engine from file. Rebuilding engine");
+      // Rebuild engine if version mismatch occurred
+      if (!build_engine_with_log()) {
+        return false;
+      }
+    } else if (!loadEngine()) {
       return false;
     }
     logger_->log(nvinfer1::ILogger::Severity::kINFO, "Network validation");
@@ -543,6 +551,33 @@ bool TrtCommon::buildEngineFromOnnx()
   return true;
 }
 
+bool TrtCommon::validateEngine()
+{
+#if (NV_TENSORRT_MAJOR * 1000) + (NV_TENSORRT_MINOR * 100) + NV_TENSORRT_PATCH >= 8600
+  std::string header(TRT_PATCH_IDX + 1, '\0');  // only the leading version bytes are inspected
+  std::ifstream engine_file(trt_config_->engine_path, std::ios::binary);
+  if (!engine_file.read(header.data(), static_cast<std::streamsize>(header.size()))) {
+    return false;  // unreadable or truncated file: report invalid so the caller rebuilds it
+  }
+  const auto * blob = reinterpret_cast<const uint8_t *>(header.data());
+  const int32_t plan_major = blob[TRT_MAJOR_IDX];
+  const int32_t plan_minor = blob[TRT_MINOR_IDX];
+  const int32_t plan_patch = blob[TRT_PATCH_IDX];
+  logger_->log(
+    nvinfer1::ILogger::Severity::kINFO, "Plan was created with TensorRT %d.%d.%d", plan_major,
+    plan_minor, plan_patch);
+  const int32_t plan_ver = plan_major * 1000 + plan_minor * 100 + plan_patch;
+  if (plan_ver != (NV_TENSORRT_MAJOR * 1000) + (NV_TENSORRT_MINOR * 100) + NV_TENSORRT_PATCH) {
+    logger_->log(
+      nvinfer1::ILogger::Severity::kWARNING,
+      "Plan was created with a different version of TensorRT! Current version: %d.%d.%d",
+      NV_TENSORRT_MAJOR, NV_TENSORRT_MINOR, NV_TENSORRT_PATCH);
+    return false;
+  }
+#endif
+  return true;
+}
+
 bool TrtCommon::loadEngine()
 {
   std::ifstream engine_file(trt_config_->engine_path);