Skip to content

Commit 7b8db04

Browse files
committed
Fix softmax accuracy issue
1 parent 29069c5 commit 7b8db04

File tree

2 files changed

+23
-40
lines changed

2 files changed

+23
-40
lines changed

src/cpp/models/src/keypoint_detection.cpp

+18-38
Original file line numberDiff line numberDiff line change
@@ -18,66 +18,46 @@
1818

1919
namespace {
2020

21-
void colArgMax(const cv::Mat& src, cv::Mat& dst_locs, cv::Mat& dst_values) {
21+
void colArgMax(const cv::Mat& src, cv::Mat& dst_locs, cv::Mat& dst_values, bool apply_softmax = false, float eps = 1e-6f) {
2222
dst_locs = cv::Mat::zeros(src.rows, 1, CV_32S);
2323
dst_values = cv::Mat::zeros(src.rows, 1, CV_32F);
2424

2525
for (int row = 0; row < src.rows; ++row) {
2626
const float* ptr_row = src.ptr<float>(row);
2727
int max_val_idx = 0;
28-
dst_values.at<float>(row) = ptr_row[max_val_idx];
28+
float max_val = ptr_row[0];
2929
for (int col = 1; col < src.cols; ++col) {
30-
if (ptr_row[col] > ptr_row[max_val_idx]) {
30+
if (ptr_row[col] > max_val) {
3131
max_val_idx = col;
3232
dst_locs.at<int>(row) = max_val_idx;
33-
dst_values.at<float>(row) = ptr_row[col];
33+
max_val = ptr_row[col];
3434
}
3535
}
36-
}
37-
}
38-
39-
cv::Mat softmax_row(const cv::Mat& src) {
40-
cv::Mat result = src.clone();
4136

42-
for (int row = 0; row < result.rows; ++row) {
43-
float* ptr_row = result.ptr<float>(row);
44-
float max_val = ptr_row[0];
45-
for (int col = 1; col < result.cols; ++col) {
46-
if (ptr_row[col] > max_val) {
47-
max_val = ptr_row[col];
37+
if (apply_softmax) {
38+
float sum = 0.0f;
39+
for (int col = 0; col < src.cols; ++col) {
40+
sum += exp(ptr_row[col] - max_val);
4841
}
42+
dst_values.at<float>(row) = exp(ptr_row[max_val_idx] - max_val) / (sum + eps);
4943
}
50-
float sum = 0.0f;
51-
for (int col = 0; col < result.cols; col++) {
52-
ptr_row[col] = exp(ptr_row[col] - max_val);
53-
sum += ptr_row[col];
54-
}
55-
for (int col = 0; col < result.cols; ++col) {
56-
ptr_row[col] /= sum;
44+
else {
45+
dst_values.at<float>(row) = max_val;
5746
}
5847
}
59-
60-
return result;
6148
}
6249

63-
DetectedKeypoints decode_simcc(const cv::Mat& simcc_x_input,
64-
const cv::Mat& simcc_y_input,
50+
DetectedKeypoints decode_simcc(const cv::Mat& simcc_x,
51+
const cv::Mat& simcc_y,
6552
const cv::Point2f& extra_scale = cv::Point2f(1.f, 1.f),
66-
float simcc_split_ratio = 2.0f,
67-
bool apply_softmax = false) {
68-
cv::Mat simcc_x = simcc_x_input;
69-
cv::Mat simcc_y = simcc_y_input;
70-
71-
if (apply_softmax) {
72-
simcc_x = softmax_row(simcc_x);
73-
simcc_x = softmax_row(simcc_y);
74-
}
75-
53+
bool apply_softmax = false,
54+
float simcc_split_ratio = 2.0f
55+
) {
7656
cv::Mat x_locs, max_val_x;
77-
colArgMax(simcc_x, x_locs, max_val_x);
57+
colArgMax(simcc_x, x_locs, max_val_x, apply_softmax);
7858

7959
cv::Mat y_locs, max_val_y;
80-
colArgMax(simcc_y, y_locs, max_val_y);
60+
colArgMax(simcc_y, y_locs, max_val_y, apply_softmax);
8161

8262
std::vector<cv::Point2f> keypoints(x_locs.rows);
8363
cv::Mat scores = cv::Mat::zeros(x_locs.rows, 1, CV_32F);

src/python/model_api/models/keypoint_detection.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def _get_simcc_maximum(
161161
simcc_x: np.ndarray,
162162
simcc_y: np.ndarray,
163163
apply_softmax: bool = False,
164+
softmax_eps: float = 1e-06,
164165
) -> tuple[np.ndarray, np.ndarray]:
165166
"""Get maximum response location and value from simcc representations.
166167
@@ -175,6 +176,8 @@ def _get_simcc_maximum(
175176
simcc_y (np.ndarray): y-axis SimCC in shape (K, Hy) or (N, K, Hy)
176177
apply_softmax (bool): whether to apply softmax on the heatmap.
177178
Defaults to False.
179+
softmax_eps (float): a constant to avoid division by zero in softmax.
180+
Defaults to 1e-6.
178181
179182
Returns:
180183
tuple:
@@ -204,8 +207,8 @@ def _get_simcc_maximum(
204207
simcc_x = simcc_x - np.max(simcc_x, axis=1, keepdims=True)
205208
simcc_y = simcc_y - np.max(simcc_y, axis=1, keepdims=True)
206209
ex, ey = np.exp(simcc_x), np.exp(simcc_y)
207-
simcc_x = ex / np.sum(ex, axis=1, keepdims=True)
208-
simcc_y = ey / np.sum(ey, axis=1, keepdims=True)
210+
simcc_x = ex / (np.sum(ex, axis=1, keepdims=True) + softmax_eps)
211+
simcc_y = ey / (np.sum(ey, axis=1, keepdims=True) + softmax_eps)
209212

210213
x_locs = np.argmax(simcc_x, axis=1)
211214
y_locs = np.argmax(simcc_y, axis=1)

0 commit comments

Comments
 (0)