Skip to content

Commit 379ba5d

Browse files
committed
Add softmax to kp det wrapper
1 parent 16818b0 commit 379ba5d

File tree

3 files changed

+61
-8
lines changed

3 files changed

+61
-8
lines changed

model_api/cpp/models/include/models/keypoint_detection.h

+1
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@ class KeypointDetectionModel : public ImageModel {
4545
static std::string ModelType;
4646

4747
protected:
48+
bool apply_softmax = true;
4849

4950
void prepareInputsOutputs(std::shared_ptr<ov::Model>& model) override;
5051
void updateModelInfo() override;

model_api/cpp/models/src/keypoint_detection.cpp

+40-4
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ void colArgMax(const cv::Mat& src, cv::Mat& dst_locs, cv::Mat& dst_values) {
3434
dst_locs = cv::Mat::zeros(src.rows, 1, CV_32S);
3535
dst_values = cv::Mat::zeros(src.rows, 1, CV_32F);
3636

37-
for (int row = 0; row < src.rows; row++) {
37+
for (int row = 0; row < src.rows; ++row) {
3838
const float *ptr_row = src.ptr<float>(row);
3939
int max_val_idx = 0;
4040
dst_values.at<float>(row) = ptr_row[max_val_idx];
@@ -48,9 +48,44 @@ void colArgMax(const cv::Mat& src, cv::Mat& dst_locs, cv::Mat& dst_values) {
4848
}
4949
}
5050

51-
DetectedKeypoints decode_simcc(const cv::Mat& simcc_x, const cv::Mat& simcc_y,
51+
52+
/// Applies a numerically stable softmax independently to every row of `src`.
///
/// Each row is shifted by its own maximum before exponentiation so that large
/// logits cannot overflow the exponential; the row is then divided by the sum
/// of its exponentials.
///
/// @param src Input matrix. Rows are accessed through `ptr<float>`, so the
///            data is assumed to be single-channel CV_32F
///            (NOTE(review): confirm against the call sites).
/// @return A new matrix of the same size as `src` holding the per-row softmax.
///         `src` is not modified. An empty matrix is returned as an unchanged
///         copy.
cv::Mat softmax_row(const cv::Mat& src) {
    cv::Mat result = src.clone();

    // Nothing to normalise; this also prevents reading ptr_row[0] from a
    // matrix that has no columns.
    if (result.empty()) {
        return result;
    }

    for (int row = 0; row < result.rows; ++row) {
        float* ptr_row = result.ptr<float>(row);

        // Row maximum, subtracted below for numerical stability.
        float max_val = ptr_row[0];
        for (int col = 1; col < result.cols; ++col) {
            if (ptr_row[col] > max_val) {
                max_val = ptr_row[col];
            }
        }

        // Exponentiate in place and accumulate the normaliser.
        // std::exp selects the float overload, avoiding a round trip via double.
        float sum = 0.0f;
        for (int col = 0; col < result.cols; ++col) {
            ptr_row[col] = std::exp(ptr_row[col] - max_val);
            sum += ptr_row[col];
        }

        // sum >= 1 here because the maximum element contributes exp(0) == 1.
        for (int col = 0; col < result.cols; ++col) {
            ptr_row[col] /= sum;
        }
    }

    return result;
}
75+
76+
77+
DetectedKeypoints decode_simcc(const cv::Mat& simcc_x_input, const cv::Mat& simcc_y_input,
5278
const cv::Point2f& extra_scale = cv::Point2f(1.f, 1.f),
53-
float simcc_split_ratio = 2.0f) {
79+
float simcc_split_ratio = 2.0f,
80+
bool apply_softmax=false) {
81+
cv::Mat simcc_x = simcc_x_input;
82+
cv::Mat simcc_y = simcc_y_input;
83+
84+
if (apply_softmax) {
85+
simcc_x = softmax_row(simcc_x);
86+
simcc_x = softmax_row(simcc_y);
87+
}
88+
5489
cv::Mat x_locs, max_val_x;
5590
colArgMax(simcc_x, x_locs, max_val_x);
5691

@@ -77,6 +112,7 @@ std::string KeypointDetectionModel::ModelType = "keypoint_detection";
77112

78113
void KeypointDetectionModel::init_from_config(const ov::AnyMap& top_priority, const ov::AnyMap& mid_priority) {
79114
labels = get_from_any_maps("labels", top_priority, mid_priority, labels);
115+
apply_softmax = get_from_any_maps("apply_softmax", top_priority, mid_priority, apply_softmax);
80116
}
81117

82118
KeypointDetectionModel::KeypointDetectionModel(std::shared_ptr<ov::Model>& model, const ov::AnyMap& configuration) : ImageModel(model, configuration) {
@@ -200,7 +236,7 @@ std::unique_ptr<ResultBase> KeypointDetectionModel::postprocess(InferenceResult&
200236
float inverted_scale_x = static_cast<float>(image_data.inputImgWidth) / netInputWidth,
201237
inverted_scale_y = static_cast<float>(image_data.inputImgHeight) / netInputHeight;
202238

203-
result->poses.emplace_back(decode_simcc(pred_x_mat, pred_y_mat, {inverted_scale_x, inverted_scale_y}));
239+
result->poses.emplace_back(decode_simcc(pred_x_mat, pred_y_mat, {inverted_scale_x, inverted_scale_y}, apply_softmax));
204240
return std::unique_ptr<ResultBase>(result);
205241
}
206242

model_api/python/model_api/models/keypoint_detection.py

+20-4
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121
import numpy as np
2222

2323
from .image_model import ImageModel
24-
from .types import ListValue
24+
from .types import BooleanValue, ListValue
2525
from .utils import DetectedKeypoints, Detection
2626

2727

@@ -59,7 +59,7 @@ def postprocess(
5959
DetectedKeypoints: detected keypoints
6060
"""
6161
encoded_kps = list(outputs.values())
62-
batch_keypoints, batch_scores = _decode_simcc(*encoded_kps)
62+
batch_keypoints, batch_scores = _decode_simcc(*encoded_kps, apply_softmax=self.apply_softmax)
6363
orig_h, orig_w = meta["original_shape"][:2]
6464
kp_scale_h = orig_h / self.h
6565
kp_scale_w = orig_w / self.w
@@ -74,6 +74,9 @@ def parameters(cls) -> dict:
7474
"labels": ListValue(
7575
description="List of class labels", value_type=str, default_value=[]
7676
),
77+
"apply_softmax": BooleanValue(
78+
default_value=True, description="Whether to apply softmax on the heatmap."
79+
),
7780
}
7881
)
7982
return parameters
@@ -127,22 +130,25 @@ def predict_crops(self, crops: list[np.ndarray]) -> list[DetectedKeypoints]:
127130

128131

129132
def _decode_simcc(
130-
simcc_x: np.ndarray, simcc_y: np.ndarray, simcc_split_ratio: float = 2.0
133+
simcc_x: np.ndarray, simcc_y: np.ndarray, simcc_split_ratio: float = 2.0,
134+
apply_softmax: bool = False,
131135
) -> tuple[np.ndarray, np.ndarray]:
132136
"""Decodes keypoint coordinates from SimCC representations. The decoded coordinates are in the input image space.
133137
134138
Args:
135139
simcc_x (np.ndarray): SimCC label for x-axis
136140
simcc_y (np.ndarray): SimCC label for y-axis
137141
simcc_split_ratio (float): The ratio of the label size to the input size.
142+
apply_softmax (bool): whether to apply softmax on the heatmap.
143+
Defaults to False.
138144
139145
Returns:
140146
tuple:
141147
- keypoints (np.ndarray): Decoded coordinates in shape (N, K, D)
142148
- scores (np.ndarray): The keypoint scores in shape (N, K).
143149
It usually represents the confidence of the keypoint prediction
144150
"""
145-
keypoints, scores = _get_simcc_maximum(simcc_x, simcc_y)
151+
keypoints, scores = _get_simcc_maximum(simcc_x, simcc_y, apply_softmax)
146152

147153
# Unsqueeze the instance dimension for single-instance results
148154
if keypoints.ndim == 2:
@@ -157,6 +163,7 @@ def _decode_simcc(
157163
def _get_simcc_maximum(
158164
simcc_x: np.ndarray,
159165
simcc_y: np.ndarray,
166+
apply_softmax: bool = False,
160167
) -> tuple[np.ndarray, np.ndarray]:
161168
"""Get maximum response location and value from simcc representations.
162169
@@ -169,6 +176,8 @@ def _get_simcc_maximum(
169176
Args:
170177
simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
171178
simcc_y (np.ndarray): y-axis SimCC in shape (K, Hy) or (N, K, Hy)
179+
apply_softmax (bool): whether to apply softmax on the heatmap.
180+
Defaults to False.
172181
173182
Returns:
174183
tuple:
@@ -194,6 +203,13 @@ def _get_simcc_maximum(
194203
else:
195204
batch_size = None
196205

206+
if apply_softmax:
207+
simcc_x = simcc_x - np.max(simcc_x, axis=1, keepdims=True)
208+
simcc_y = simcc_y - np.max(simcc_y, axis=1, keepdims=True)
209+
ex, ey = np.exp(simcc_x), np.exp(simcc_y)
210+
simcc_x = ex / np.sum(ex, axis=1, keepdims=True)
211+
simcc_y = ey / np.sum(ey, axis=1, keepdims=True)
212+
197213
x_locs = np.argmax(simcc_x, axis=1)
198214
y_locs = np.argmax(simcc_y, axis=1)
199215
locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)

0 commit comments

Comments
 (0)