Skip to content

Commit 23bb4a0

Browse files
fix(autoware_lidar_transfusion): set tensor names by matching with predefined values. (#9057)
* set tensor order using api Signed-off-by: Samrat Thapa <samratthapa120@gmail.com> * style(pre-commit): autofix Signed-off-by: Samrat Thapa <samratthapa120@gmail.com> * fix tensor order Signed-off-by: Samrat Thapa <samratthapa120@gmail.com> * style(pre-commit): autofix Signed-off-by: Samrat Thapa <samratthapa120@gmail.com> * style fix Signed-off-by: Samrat Thapa <samratthapa120@gmail.com> * style(pre-commit): autofix --------- Signed-off-by: Samrat Thapa <samratthapa120@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 146be20 commit 23bb4a0

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

perception/autoware_lidar_transfusion/include/autoware/lidar_transfusion/utils.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ struct Box3D
3636
float yaw;
3737
};
3838

39-
enum NetworkIO { voxels = 0, num_points, coors, cls_score, dir_pred, bbox_pred, ENUM_SIZE };
39+
enum NetworkIO { voxels = 0, num_points, coors, cls_score, bbox_pred, dir_pred, ENUM_SIZE };
4040

4141
// cspell: ignore divup
4242
template <typename T1, typename T2>

perception/autoware_lidar_transfusion/lib/network/network_trt.cpp

+21-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,20 @@
2323
namespace autoware::lidar_transfusion
2424
{
2525

26+
inline NetworkIO nameToNetworkIO(const char * name)
27+
{
28+
static const std::unordered_map<std::string_view, NetworkIO> name_to_enum = {
29+
{"voxels", NetworkIO::voxels}, {"num_points", NetworkIO::num_points},
30+
{"coors", NetworkIO::coors}, {"cls_score0", NetworkIO::cls_score},
31+
{"bbox_pred0", NetworkIO::bbox_pred}, {"dir_cls_pred0", NetworkIO::dir_pred}};
32+
33+
auto it = name_to_enum.find(name);
34+
if (it != name_to_enum.end()) {
35+
return it->second;
36+
}
37+
throw std::runtime_error("Invalid input name: " + std::string(name));
38+
}
39+
2640
std::ostream & operator<<(std::ostream & os, const ProfileDimension & profile)
2741
{
2842
std::string delim = "";
@@ -253,8 +267,14 @@ bool NetworkTRT::validateNetworkIO()
253267
<< ". Actual size: " << engine->getNbIOTensors() << "." << std::endl;
254268
throw std::runtime_error("Failed to initialize TRT network.");
255269
}
270+
271+
// Initialize tensors_names_ with null values
272+
tensors_names_.resize(NetworkIO::ENUM_SIZE, nullptr);
273+
274+
// Loop over the tensor names and place them in the correct positions
256275
for (int i = 0; i < NetworkIO::ENUM_SIZE; ++i) {
257-
tensors_names_.push_back(engine->getIOTensorName(i));
276+
const char * name = engine->getIOTensorName(i);
277+
tensors_names_[nameToNetworkIO(name)] = name;
258278
}
259279

260280
// Log the network IO

0 commit comments

Comments
 (0)