Skip to content

Commit ad086ba

Browse files
committed
fix(autoware_tensorrt_common): invalid engine model file
Signed-off-by: Grzegorz Głowacki <gglowacki@autonomous-systems.pl>
1 parent 1bebec9 commit ad086ba

File tree

1 file changed

+75
-1
lines changed

1 file changed

+75
-1
lines changed

perception/autoware_tensorrt_common/src/tensorrt_common.cpp

+75-1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,45 @@ namespace autoware
3636
namespace tensorrt_common
3737
{
3838

39+
class TrtErrorRecorder : public nvinfer1::IErrorRecorder
40+
{
41+
public:
42+
struct Error
43+
{
44+
nvinfer1::ErrorCode code;
45+
std::string desc;
46+
};
47+
const std::vector<Error> & getErrors() const noexcept { return errors_; }
48+
49+
private:
50+
RefCount incRefCount() noexcept override { return ++ref_count_; }
51+
52+
RefCount decRefCount() noexcept override { return --ref_count_; }
53+
54+
int32_t getNbErrors() const noexcept override { return static_cast<int32_t>(errors_.size()); }
55+
56+
nvinfer1::ErrorCode getErrorCode(int32_t errorIdx) const noexcept override
57+
{
58+
return errors_[errorIdx].code;
59+
}
60+
61+
ErrorDesc getErrorDesc(int32_t errorIdx) const noexcept override
62+
{
63+
return errors_[errorIdx].desc.c_str();
64+
}
65+
66+
bool hasOverflowed() const noexcept override { return false; }
67+
bool reportError(nvinfer1::ErrorCode val, ErrorDesc desc) noexcept override
68+
{
69+
errors_.push_back({val, std::string(desc)});
70+
return false;
71+
}
72+
void clear() noexcept override { errors_.clear(); }
73+
74+
std::vector<Error> errors_;
75+
std::atomic<nvinfer1::IErrorRecorder::RefCount> ref_count_{0};
76+
};
77+
3978
TrtCommon::TrtCommon(
4079
const TrtCommonConfig & trt_config, const std::shared_ptr<Profiler> & profiler,
4180
const std::vector<std::string> & plugin_paths)
@@ -543,19 +582,54 @@ bool TrtCommon::buildEngineFromOnnx()
543582
return true;
544583
}
545584

585+
auto setup_error_recorder(TrtUniquePtr<nvinfer1::IRuntime> & runtime)
586+
{
587+
auto errorRecorder = std::make_unique<TrtErrorRecorder>();
588+
runtime->setErrorRecorder(errorRecorder.get());
589+
return std::make_pair(std::move(errorRecorder), runtime->getErrorRecorder());
590+
}
591+
592+
/// Registers the given error recorder on the runtime.
///
/// @param runtime     TensorRT runtime whose error recorder is replaced.
/// @param defRecorder Recorder to install; forwarded unchanged to setErrorRecorder().
void restore_default_recorder(
  TrtUniquePtr<nvinfer1::IRuntime> & runtime, nvinfer1::IErrorRecorder * defRecorder)
{
  nvinfer1::IRuntime & target = *runtime;
  target.setErrorRecorder(defRecorder);
}
597+
546598
bool TrtCommon::loadEngine()
547599
{
548600
std::ifstream engine_file(trt_config_->engine_path);
549601
std::stringstream engine_buffer;
550602
engine_buffer << engine_file.rdbuf();
551603
std::string engine_str = engine_buffer.str();
552604

605+
auto [errorRecorder, defRecorder] = setup_error_recorder(runtime_);
606+
553607
engine_ = TrtUniquePtr<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(
554608
reinterpret_cast<const void *>( // NOLINT
555609
engine_str.data()),
556610
engine_str.size()));
611+
612+
restore_default_recorder(runtime_, defRecorder);
613+
557614
if (!engine_) {
558-
logger_->log(nvinfer1::ILogger::Severity::kERROR, "Fail to create engine");
615+
for (const auto & error : errorRecorder->getErrors()) {
616+
auto code = error.code;
617+
auto desc = error.desc;
618+
logger_->log(
619+
nvinfer1::ILogger::Severity::kERROR, "Error code: %d, Description: %s", code, desc.c_str());
620+
if (
621+
code == nvinfer1::ErrorCode::kUNSPECIFIED_ERROR &&
622+
desc.find("Serialization assertion stdVersionRead == kSERIALIZATION_VERSION failed") !=
623+
std::string::npos) {
624+
fs::remove(trt_config_->engine_path);
625+
logger_->log(
626+
nvinfer1::ILogger::Severity::kERROR,
627+
"Engine file %s removed due to version mismatch. This file will be regenerated, next "
628+
"time the node starts.",
629+
trt_config_->engine_path.c_str());
630+
}
631+
}
632+
logger_->log(nvinfer1::ILogger::Severity::kERROR, "Failed to create engine");
559633
return false;
560634
}
561635

0 commit comments

Comments
 (0)