@@ -36,6 +36,45 @@ namespace autoware
36
36
namespace tensorrt_common
37
37
{
38
38
39
+ class TrtErrorRecorder : public nvinfer1 ::IErrorRecorder
40
+ {
41
+ public:
42
+ struct Error
43
+ {
44
+ nvinfer1::ErrorCode code;
45
+ std::string desc;
46
+ };
47
+ const std::vector<Error> & getErrors () const noexcept { return errors_; }
48
+
49
+ private:
50
+ RefCount incRefCount () noexcept override { return ++ref_count_; }
51
+
52
+ RefCount decRefCount () noexcept override { return --ref_count_; }
53
+
54
+ int32_t getNbErrors () const noexcept override { return static_cast <int32_t >(errors_.size ()); }
55
+
56
+ nvinfer1::ErrorCode getErrorCode (int32_t errorIdx) const noexcept override
57
+ {
58
+ return errors_[errorIdx].code ;
59
+ }
60
+
61
+ ErrorDesc getErrorDesc (int32_t errorIdx) const noexcept override
62
+ {
63
+ return errors_[errorIdx].desc .c_str ();
64
+ }
65
+
66
+ bool hasOverflowed () const noexcept override { return false ; }
67
+ bool reportError (nvinfer1::ErrorCode val, ErrorDesc desc) noexcept override
68
+ {
69
+ errors_.push_back ({val, std::string (desc)});
70
+ return false ;
71
+ }
72
+ void clear () noexcept override { errors_.clear (); }
73
+
74
+ std::vector<Error> errors_;
75
+ std::atomic<nvinfer1::IErrorRecorder::RefCount> ref_count_{0 };
76
+ };
77
+
39
78
TrtCommon::TrtCommon (
40
79
const TrtCommonConfig & trt_config, const std::shared_ptr<Profiler> & profiler,
41
80
const std::vector<std::string> & plugin_paths)
@@ -543,19 +582,54 @@ bool TrtCommon::buildEngineFromOnnx()
543
582
return true ;
544
583
}
545
584
585
+ auto setup_error_recorder (TrtUniquePtr<nvinfer1::IRuntime> & runtime)
586
+ {
587
+ auto errorRecorder = std::make_unique<TrtErrorRecorder>();
588
+ runtime->setErrorRecorder (errorRecorder.get ());
589
+ return std::make_pair (std::move (errorRecorder), runtime->getErrorRecorder ());
590
+ }
591
+
592
+ void restore_default_recorder (
593
+ TrtUniquePtr<nvinfer1::IRuntime> & runtime, nvinfer1::IErrorRecorder * defRecorder)
594
+ {
595
+ runtime->setErrorRecorder (defRecorder);
596
+ }
597
+
546
598
bool TrtCommon::loadEngine ()
547
599
{
548
600
std::ifstream engine_file (trt_config_->engine_path );
549
601
std::stringstream engine_buffer;
550
602
engine_buffer << engine_file.rdbuf ();
551
603
std::string engine_str = engine_buffer.str ();
552
604
605
+ auto [errorRecorder, defRecorder] = setup_error_recorder (runtime_);
606
+
553
607
engine_ = TrtUniquePtr<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine (
554
608
reinterpret_cast <const void *>( // NOLINT
555
609
engine_str.data ()),
556
610
engine_str.size ()));
611
+
612
+ restore_default_recorder (runtime_, defRecorder);
613
+
557
614
if (!engine_) {
558
- logger_->log (nvinfer1::ILogger::Severity::kERROR , " Fail to create engine" );
615
+ for (const auto & error : errorRecorder->getErrors ()) {
616
+ auto code = error.code ;
617
+ auto desc = error.desc ;
618
+ logger_->log (
619
+ nvinfer1::ILogger::Severity::kERROR , " Error code: %d, Description: %s" , code, desc.c_str ());
620
+ if (
621
+ code == nvinfer1::ErrorCode::kUNSPECIFIED_ERROR &&
622
+ desc.find (" Serialization assertion stdVersionRead == kSERIALIZATION_VERSION failed" ) !=
623
+ std::string::npos) {
624
+ fs::remove (trt_config_->engine_path );
625
+ logger_->log (
626
+ nvinfer1::ILogger::Severity::kERROR ,
627
+ " Engine file %s removed due to version mismatch. This file will be regenerated, next "
628
+ " time the node starts." ,
629
+ trt_config_->engine_path .c_str ());
630
+ }
631
+ }
632
+ logger_->log (nvinfer1::ILogger::Severity::kERROR , " Failed to create engine" );
559
633
return false ;
560
634
}
561
635
0 commit comments