Skip to content

Commit 0e3f454

Browse files
committed
Add visibility score computation to cpp
1 parent 2d91938 commit 0e3f454

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

src/cpp/models/src/keypoint_detection.cpp

+12-4
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,24 @@ DetectedKeypoints decode_simcc(const cv::Mat& simcc_x,
5555
const cv::Point2f& extra_scale = cv::Point2f(1.f, 1.f),
5656
const cv::Point2i& extra_offset = cv::Point2f(0.f, 0.f),
5757
bool apply_softmax = false,
58-
float simcc_split_ratio = 2.0f) {
58+
float simcc_split_ratio = 2.0f,
59+
float decode_beta = 150.0f,
60+
float sigma = 6.0f) {
5961
cv::Mat x_locs, max_val_x;
60-
colArgMax(simcc_x, x_locs, max_val_x, apply_softmax);
62+
colArgMax(simcc_x, x_locs, max_val_x, false);
6163

6264
cv::Mat y_locs, max_val_y;
63-
colArgMax(simcc_y, y_locs, max_val_y, apply_softmax);
65+
colArgMax(simcc_y, y_locs, max_val_x, false);
66+
67+
if (apply_softmax) {
68+
cv::Mat tmp_locs;
69+
colArgMax(decode_beta * sigma * simcc_x, tmp_locs, max_val_x, true);
70+
colArgMax(decode_beta * sigma * simcc_y, tmp_locs, max_val_y, true);
71+
}
6472

6573
std::vector<cv::Point2f> keypoints(x_locs.rows);
6674
cv::Mat scores = cv::Mat::zeros(x_locs.rows, 1, CV_32F);
67-
for (int i = 0; i < x_locs.rows; i++) {
75+
for (int i = 0; i < x_locs.rows; ++i) {
6876
keypoints[i] = cv::Point2f((x_locs.at<int>(i) - extra_offset.x) * extra_scale.x,
6977
(y_locs.at<int>(i) - extra_offset.y) * extra_scale.y) /
7078
simcc_split_ratio;

0 commit comments

Comments
 (0)