1
- // Copyright 2021 Tier IV, Inc.
1
+ // Copyright 2021 TIER IV, Inc.
2
2
//
3
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
4
// you may not use this file except in compliance with the License.
18
18
#include < config.hpp>
19
19
#include < cuda_utils.hpp>
20
20
#include < network_trt.hpp>
21
- #include < tier4_autoware_utils/math/constants .hpp>
21
+ #include < postprocess_kernel .hpp>
22
22
#include < voxel_generator.hpp>
23
23
24
24
#include < sensor_msgs/msg/point_cloud2.hpp>
25
25
26
26
#include < pcl/point_cloud.h>
27
27
#include < pcl/point_types.h>
28
- #include < torch/script.h>
29
28
30
29
#include < memory>
31
30
#include < string>
@@ -37,75 +36,72 @@ namespace centerpoint
37
36
class NetworkParam
38
37
{
39
38
public:
40
- NetworkParam (
41
- std::string onnx_path, std::string engine_path, std::string pt_path, std::string trt_precision,
42
- const bool use_trt)
39
+ NetworkParam (std::string onnx_path, std::string engine_path, std::string trt_precision)
43
40
: onnx_path_(std::move(onnx_path)),
44
41
engine_path_ (std::move(engine_path)),
45
- pt_path_(std::move(pt_path)),
46
- trt_precision_(std::move(trt_precision)),
47
- use_trt_(use_trt)
42
+ trt_precision_(std::move(trt_precision))
48
43
{
49
44
}
50
45
51
46
std::string onnx_path () const { return onnx_path_; }
52
47
std::string engine_path () const { return engine_path_; }
53
- std::string pt_path () const { return pt_path_; }
54
48
std::string trt_precision () const { return trt_precision_; }
55
- bool use_trt () const { return use_trt_; }
56
49
57
50
private:
58
51
std::string onnx_path_;
59
52
std::string engine_path_;
60
- std::string pt_path_;
61
53
std::string trt_precision_;
62
- bool use_trt_;
63
54
};
64
55
65
56
class CenterPointTRT
66
57
{
67
58
public:
68
59
explicit CenterPointTRT (
69
- const int num_class, const NetworkParam & encoder_param , const NetworkParam & head_param ,
70
- const DensificationParam & densification_param);
60
+ const std:: size_t num_class, const float score_threshold , const NetworkParam & encoder_param ,
61
+ const NetworkParam & head_param, const DensificationParam & densification_param);
71
62
72
63
~CenterPointTRT ();
73
64
74
- std::vector<float > detect (
75
- const sensor_msgs::msg::PointCloud2 &, const tf2_ros::Buffer & tf_buffer);
65
+ bool detect (
66
+ const sensor_msgs::msg::PointCloud2 & input_pointcloud_msg, const tf2_ros::Buffer & tf_buffer,
67
+ std::vector<Box3D> & det_boxes3d);
76
68
77
69
private:
78
- bool initPtr (bool use_encoder_trt, bool use_head_trt);
79
-
80
- bool loadTorchScript (torch::jit::script::Module & module, const std::string & model_path);
81
-
82
- static at::Tensor createInputFeatures (
83
- const at::Tensor & voxels, const at::Tensor & coords, const at::Tensor & voxel_num_points);
84
-
85
- static at::Tensor scatterPillarFeatures (
86
- const at::Tensor & pillar_features, const at::Tensor & coordinates);
87
-
88
- at::Tensor generatePredictedBoxes ();
89
-
90
- std::unique_ptr<VoxelGeneratorTemplate> vg_ptr_ = nullptr ;
91
- torch::jit::script::Module encoder_pt_;
92
- torch::jit::script::Module head_pt_;
93
- std::unique_ptr<VoxelEncoderTRT> encoder_trt_ptr_ = nullptr ;
94
- std::unique_ptr<HeadTRT> head_trt_ptr_ = nullptr ;
95
- c10::Device device_ = torch::kCUDA ;
96
- cudaStream_t stream_ = nullptr ;
97
-
98
- int num_class_{0 };
99
- at::Tensor voxels_t_;
100
- at::Tensor coordinates_t_;
101
- at::Tensor num_points_per_voxel_t_;
102
- at::Tensor output_pillar_feature_t_;
103
- at::Tensor output_heatmap_t_;
104
- at::Tensor output_offset_t_;
105
- at::Tensor output_z_t_;
106
- at::Tensor output_dim_t_;
107
- at::Tensor output_rot_t_;
108
- at::Tensor output_vel_t_;
70
+ void initPtr ();
71
+
72
+ bool preprocess (
73
+ const sensor_msgs::msg::PointCloud2 & input_pointcloud_msg, const tf2_ros::Buffer & tf_buffer);
74
+
75
+ void inference ();
76
+
77
+ void postProcess (std::vector<Box3D> & det_boxes3d);
78
+
79
+ std::unique_ptr<VoxelGeneratorTemplate> vg_ptr_{nullptr };
80
+ std::unique_ptr<VoxelEncoderTRT> encoder_trt_ptr_{nullptr };
81
+ std::unique_ptr<HeadTRT> head_trt_ptr_{nullptr };
82
+ std::unique_ptr<PostProcessCUDA> post_proc_ptr_{nullptr };
83
+ cudaStream_t stream_{nullptr };
84
+
85
+ bool verbose_{false };
86
+ std::size_t num_class_{0 };
87
+ std::size_t num_voxels_{0 };
88
+ std::size_t encoder_in_feature_size_{0 };
89
+ std::size_t spatial_features_size_{0 };
90
+ std::vector<float > voxels_;
91
+ std::vector<int > coordinates_;
92
+ std::vector<float > num_points_per_voxel_;
93
+ cuda::unique_ptr<float []> voxels_d_{nullptr };
94
+ cuda::unique_ptr<int []> coordinates_d_{nullptr };
95
+ cuda::unique_ptr<float []> num_points_per_voxel_d_{nullptr };
96
+ cuda::unique_ptr<float []> encoder_in_features_d_{nullptr };
97
+ cuda::unique_ptr<float []> pillar_features_d_{nullptr };
98
+ cuda::unique_ptr<float []> spatial_features_d_{nullptr };
99
+ cuda::unique_ptr<float []> head_out_heatmap_d_{nullptr };
100
+ cuda::unique_ptr<float []> head_out_offset_d_{nullptr };
101
+ cuda::unique_ptr<float []> head_out_z_d_{nullptr };
102
+ cuda::unique_ptr<float []> head_out_dim_d_{nullptr };
103
+ cuda::unique_ptr<float []> head_out_rot_d_{nullptr };
104
+ cuda::unique_ptr<float []> head_out_vel_d_{nullptr };
109
105
};
110
106
111
107
} // namespace centerpoint
0 commit comments