1
1
#
2
- # Copyright (C) 2020-2024 Intel Corporation
2
+ # Copyright (C) 2020-2025 Intel Corporation
3
3
# SPDX-License-Identifier: Apache-2.0
4
4
#
5
5
11
11
12
12
from .image_model import ImageModel
13
13
from .result import DetectedKeypoints , DetectionResult
14
- from .types import ListValue
14
+ from .types import BooleanValue , ListValue
15
15
16
16
17
17
class KeypointDetectionModel (ImageModel ):
@@ -30,6 +30,7 @@ def __init__(self, inference_adapter, configuration: dict = {}, preload=False):
30
30
"""
31
31
super ().__init__ (inference_adapter , configuration , preload )
32
32
self ._check_io_number (1 , 2 )
33
+ self .apply_softmax : bool
33
34
34
35
def postprocess (
35
36
self ,
@@ -46,7 +47,11 @@ def postprocess(
46
47
DetectedKeypoints: detected keypoints
47
48
"""
48
49
encoded_kps = list (outputs .values ())
49
- batch_keypoints , batch_scores = _decode_simcc (* encoded_kps )
50
+ batch_keypoints , batch_scores = _decode_simcc (
51
+ encoded_kps [0 ],
52
+ encoded_kps [1 ],
53
+ apply_softmax = self .apply_softmax ,
54
+ )
50
55
orig_h , orig_w = meta ["original_shape" ][:2 ]
51
56
kp_scale_h = orig_h / self .h
52
57
kp_scale_w = orig_w / self .w
@@ -63,6 +68,10 @@ def parameters(cls) -> dict:
63
68
value_type = str ,
64
69
default_value = [],
65
70
),
71
+ "apply_softmax" : BooleanValue (
72
+ default_value = True ,
73
+ description = "Whether to apply softmax on the heatmap." ,
74
+ ),
66
75
},
67
76
)
68
77
return parameters
@@ -119,21 +128,24 @@ def _decode_simcc(
119
128
simcc_x : np .ndarray ,
120
129
simcc_y : np .ndarray ,
121
130
simcc_split_ratio : float = 2.0 ,
131
+ apply_softmax : bool = False ,
122
132
) -> tuple [np .ndarray , np .ndarray ]:
123
133
"""Decodes keypoint coordinates from SimCC representations. The decoded coordinates are in the input image space.
124
134
125
135
Args:
126
136
simcc_x (np.ndarray): SimCC label for x-axis
127
137
simcc_y (np.ndarray): SimCC label for y-axis
128
138
simcc_split_ratio (float): The ratio of the label size to the input size.
139
+ apply_softmax (bool): whether to apply softmax on the heatmap.
140
+ Defaults to False.
129
141
130
142
Returns:
131
143
tuple:
132
144
- keypoints (np.ndarray): Decoded coordinates in shape (N, K, D)
133
145
- scores (np.ndarray): The keypoint scores in shape (N, K).
134
146
It usually represents the confidence of the keypoint prediction
135
147
"""
136
- keypoints , scores = _get_simcc_maximum (simcc_x , simcc_y )
148
+ keypoints , scores = _get_simcc_maximum (simcc_x , simcc_y , apply_softmax )
137
149
138
150
# Unsqueeze the instance dimension for single-instance results
139
151
if keypoints .ndim == 2 :
@@ -148,6 +160,8 @@ def _decode_simcc(
148
160
def _get_simcc_maximum (
149
161
simcc_x : np .ndarray ,
150
162
simcc_y : np .ndarray ,
163
+ apply_softmax : bool = False ,
164
+ softmax_eps : float = 1e-06 ,
151
165
) -> tuple [np .ndarray , np .ndarray ]:
152
166
"""Get maximum response location and value from simcc representations.
153
167
@@ -160,6 +174,10 @@ def _get_simcc_maximum(
160
174
Args:
161
175
simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
162
176
simcc_y (np.ndarray): y-axis SimCC in shape (K, Hy) or (N, K, Hy)
177
+ apply_softmax (bool): whether to apply softmax on the heatmap.
178
+ Defaults to False.
179
+ softmax_eps (flat): a constant to avoid division by zero in softmax.
180
+ Defaults to 1e-6.
163
181
164
182
Returns:
165
183
tuple:
@@ -185,6 +203,13 @@ def _get_simcc_maximum(
185
203
else :
186
204
batch_size = None
187
205
206
+ if apply_softmax :
207
+ simcc_x = simcc_x - np .max (simcc_x , axis = 1 , keepdims = True )
208
+ simcc_y = simcc_y - np .max (simcc_y , axis = 1 , keepdims = True )
209
+ ex , ey = np .exp (simcc_x ), np .exp (simcc_y )
210
+ simcc_x = ex / (np .sum (ex , axis = 1 , keepdims = True ) + softmax_eps )
211
+ simcc_y = ey / (np .sum (ey , axis = 1 , keepdims = True ) + softmax_eps )
212
+
188
213
x_locs = np .argmax (simcc_x , axis = 1 )
189
214
y_locs = np .argmax (simcc_y , axis = 1 )
190
215
locs = np .stack ((x_locs , y_locs ), axis = - 1 ).astype (np .float32 )
0 commit comments