21
21
import numpy as np
22
22
23
23
from .image_model import ImageModel
24
- from .types import ListValue
24
+ from .types import BooleanValue , ListValue
25
25
from .utils import DetectedKeypoints , Detection
26
26
27
27
@@ -74,6 +74,9 @@ def parameters(cls) -> dict:
74
74
"labels" : ListValue (
75
75
description = "List of class labels" , value_type = str , default_value = []
76
76
),
77
+ "apply_softmax" : BooleanValue (
78
+ default_value = True , description = "Whether to apply softmax on the heatmap."
79
+ ),
77
80
}
78
81
)
79
82
return parameters
@@ -127,22 +130,25 @@ def predict_crops(self, crops: list[np.ndarray]) -> list[DetectedKeypoints]:
127
130
128
131
129
132
def _decode_simcc (
130
- simcc_x : np .ndarray , simcc_y : np .ndarray , simcc_split_ratio : float = 2.0
133
+ simcc_x : np .ndarray , simcc_y : np .ndarray , simcc_split_ratio : float = 2.0 ,
134
+ apply_softmax : bool = False ,
131
135
) -> tuple [np .ndarray , np .ndarray ]:
132
136
"""Decodes keypoint coordinates from SimCC representations. The decoded coordinates are in the input image space.
133
137
134
138
Args:
135
139
simcc_x (np.ndarray): SimCC label for x-axis
136
140
simcc_y (np.ndarray): SimCC label for y-axis
137
141
simcc_split_ratio (float): The ratio of the label size to the input size.
142
+ apply_softmax (bool): whether to apply softmax on the heatmap.
143
+ Defaults to False.
138
144
139
145
Returns:
140
146
tuple:
141
147
- keypoints (np.ndarray): Decoded coordinates in shape (N, K, D)
142
148
- scores (np.ndarray): The keypoint scores in shape (N, K).
143
149
It usually represents the confidence of the keypoint prediction
144
150
"""
145
- keypoints , scores = _get_simcc_maximum (simcc_x , simcc_y )
151
+ keypoints , scores = _get_simcc_maximum (simcc_x , simcc_y , apply_softmax )
146
152
147
153
# Unsqueeze the instance dimension for single-instance results
148
154
if keypoints .ndim == 2 :
@@ -157,6 +163,7 @@ def _decode_simcc(
157
163
def _get_simcc_maximum (
158
164
simcc_x : np .ndarray ,
159
165
simcc_y : np .ndarray ,
166
+ apply_softmax : bool = False ,
160
167
) -> tuple [np .ndarray , np .ndarray ]:
161
168
"""Get maximum response location and value from simcc representations.
162
169
@@ -169,6 +176,8 @@ def _get_simcc_maximum(
169
176
Args:
170
177
simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
171
178
simcc_y (np.ndarray): y-axis SimCC in shape (K, Hy) or (N, K, Hy)
179
+ apply_softmax (bool): whether to apply softmax on the heatmap.
180
+ Defaults to False.
172
181
173
182
Returns:
174
183
tuple:
@@ -194,6 +203,13 @@ def _get_simcc_maximum(
194
203
else :
195
204
batch_size = None
196
205
206
+ if apply_softmax :
207
+ simcc_x = simcc_x - np .max (simcc_x , axis = 1 , keepdims = True )
208
+ simcc_y = simcc_y - np .max (simcc_y , axis = 1 , keepdims = True )
209
+ ex , ey = np .exp (simcc_x ), np .exp (simcc_y )
210
+ simcc_x = ex / np .sum (ex , axis = 1 , keepdims = True )
211
+ simcc_y = ey / np .sum (ey , axis = 1 , keepdims = True )
212
+
197
213
x_locs = np .argmax (simcc_x , axis = 1 )
198
214
y_locs = np .argmax (simcc_y , axis = 1 )
199
215
locs = np .stack ((x_locs , y_locs ), axis = - 1 ).astype (np .float32 )
0 commit comments