Skip to content

Commit 5eab00a

Browse files
author
Mateusz Bencer
authored
Fixed default value of score threshold (openvinotoolkit#17448)
1 parent 014eafd commit 5eab00a

File tree

3 files changed

+159
-1
lines changed

3 files changed

+159
-1
lines changed

src/frontends/onnx/frontend/src/op/non_max_suppression.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ OutputVector non_max_suppression(const Node& node) {
4444
if (ng_inputs.size() > 4 && !is_null(ng_inputs.at(4))) {
4545
score_threshold = ngraph::onnx_import::reshape::interpret_as_scalar(ng_inputs.at(4));
4646
} else {
47-
score_threshold = default_opset::Constant::create(element::f32, Shape{}, {.0f});
47+
score_threshold = default_opset::Constant::create(element::f32, Shape{}, {-std::numeric_limits<float>::max()});
4848
}
4949

5050
const auto center_point_box = node.get_attribute_value<std::int64_t>("center_point_box", 0);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
ir_version: 6
2+
producer_name: "ONNX Frontend"
3+
graph {
4+
node {
5+
output: "max_output_boxes"
6+
name: "Constant_1521"
7+
op_type: "Constant"
8+
attribute {
9+
name: "value"
10+
t {
11+
dims: 1
12+
data_type: 7
13+
raw_data: "\377\377\377\377\377\377\377\177"
14+
}
15+
type: TENSOR
16+
}
17+
}
18+
node {
19+
output: "iou_threshold"
20+
name: "Constant_1522"
21+
op_type: "Constant"
22+
attribute {
23+
name: "value"
24+
t {
25+
dims: 1
26+
data_type: 1
27+
raw_data: "333?"
28+
}
29+
type: TENSOR
30+
}
31+
}
32+
node {
33+
input: "boxes"
34+
input: "scores"
35+
input: "max_output_boxes"
36+
input: "iou_threshold"
37+
output: "selected_indices"
38+
op_type: "NonMaxSuppression"
39+
}
40+
input {
41+
name: "boxes"
42+
type {
43+
tensor_type {
44+
elem_type: 1
45+
shape {
46+
dim {
47+
dim_value: 1
48+
}
49+
dim {
50+
dim_value: 50
51+
}
52+
dim {
53+
dim_value: 4
54+
}
55+
}
56+
}
57+
}
58+
}
59+
input {
60+
name: "scores"
61+
type {
62+
tensor_type {
63+
elem_type: 1
64+
shape {
65+
dim {
66+
dim_value: 1
67+
}
68+
dim {
69+
dim_value: 1
70+
}
71+
dim {
72+
dim_value: 50
73+
}
74+
}
75+
}
76+
}
77+
}
78+
output {
79+
name: "selected_indices"
80+
type {
81+
tensor_type {
82+
elem_type: 7
83+
}
84+
}
85+
}
86+
}
87+
opset_import {
88+
version: 11
89+
}

src/frontends/onnx/tests/onnx_import.in.cpp

+69
Original file line numberDiff line numberDiff line change
@@ -1080,6 +1080,75 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_nonmaxsuppression_v9_single_box) {
10801080
test_case.run();
10811081
}
10821082

1083+
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_nonmaxsuppression_default_score_threshold) {
1084+
auto function = onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
1085+
SERIALIZED_ZOO,
1086+
"onnx/nms_default_score_threshold.onnx"));
1087+
1088+
auto test_case = test::TestCase(function, s_device);
1089+
1090+
test_case.add_input(
1091+
Shape{1, 50, 4},
1092+
std::vector<float>(
1093+
{278.862060546875f, 453.5412902832031f, 295.09234619140625f, 470.2095031738281f, 225.9730682373047f,
1094+
387.33990478515625f, 241.69297790527344f, 403.43377685546875f, 281.3062438964844f, 453.8412170410156f,
1095+
298.6865539550781f, 470.9977111816406f, 216.9517364501953f, 450.6717529296875f, 232.95777893066406f,
1096+
466.14276123046875f, 217.54473876953125f, 449.9130859375f, 233.97265625f, 466.1539306640625f,
1097+
279.0079650878906f, 453.865234375f, 294.8210144042969f, 470.123046875f, 226.5626983642578f,
1098+
388.5235290527344f, 242.2290496826172f, 404.2589416503906f, 216.49752807617188f, 450.7710876464844f,
1099+
233.07443237304688f, 466.7010192871094f, 281.3638000488281f, 454.33892822265625f, 298.5252990722656f,
1100+
471.1678466796875f, 217.3330841064453f, 451.484130859375f, 234.1898651123047f, 466.83148193359375f,
1101+
187.2439727783203f, 466.8524475097656f, 208.7089385986328f, 489.7967224121094f, 257.8833923339844f,
1102+
515.705322265625f, 280.8927917480469f, 539.775146484375f, 226.52525329589844f, 387.7011413574219f,
1103+
241.6272430419922f, 403.7854919433594f, 187.38221740722656f, 466.5717468261719f, 209.05845642089844f,
1104+
489.4494323730469f, 217.56448364257812f, 451.1393737792969f, 233.90216064453125f, 466.1475524902344f,
1105+
279.45611572265625f, 454.00299072265625f, 296.16424560546875f, 471.84521484375f, 279.04486083984375f,
1106+
453.9889221191406f, 295.2816162109375f, 470.4144592285156f, 187.18997192382812f, 466.4650573730469f,
1107+
209.26266479492188f, 488.8149719238281f, 189.04197692871094f, 469.8923034667969f, 208.8195037841797f,
1108+
491.5357971191406f, 216.47879028320312f, 450.1073303222656f, 233.21575927734375f, 466.9475402832031f,
1109+
278.86163330078125f, 454.966552734375f, 296.38958740234375f, 471.9764404296875f, 259.4800720214844f,
1110+
515.1390991210938f, 282.3655090332031f, 539.4806518554688f, 285.031494140625f, 389.0125427246094f,
1111+
302.09747314453125f, 406.9799499511719f, 285.1270446777344f, 389.06890869140625f, 301.2108459472656f,
1112+
405.7711181640625f, 188.17117309570312f, 467.71533203125f, 208.49929809570312f, 490.401611328125f,
1113+
278.93292236328125f, 453.8080139160156f, 295.4295654296875f, 469.9015808105469f, 279.0393371582031f,
1114+
454.2393798828125f, 296.3529357910156f, 471.6363525390625f, 187.29873657226562f, 467.9837951660156f,
1115+
208.29107666015625f, 489.8014221191406f, 187.79478454589844f, 466.6510314941406f, 208.3644561767578f,
1116+
490.2976989746094f, 188.4196014404297f, 468.3448486328125f, 209.06849670410156f, 491.94384765625f,
1117+
281.4726867675781f, 454.0541687011719f, 298.2876892089844f, 470.2845764160156f, 225.8560333251953f,
1118+
387.4819030761719f, 241.4767608642578f, 403.4317321777344f, 280.7021484375f, 455.43206787109375f,
1119+
297.9931640625f, 471.99749755859375f, 226.0373077392578f, 387.4749450683594f, 241.48097229003906f,
1120+
403.4716491699219f, 259.018310546875f, 515.3871459960938f, 281.7872314453125f, 540.0093383789062f,
1121+
217.71246337890625f, 450.4556884765625f, 234.254150390625f, 467.68182373046875f, 257.5479736328125f,
1122+
518.8912353515625f, 280.48260498046875f, 541.3863525390625f, 216.87359619140625f, 450.3395080566406f,
1123+
232.39752197265625f, 465.5039367675781f, 258.2445068359375f, 515.2009887695312f, 280.29803466796875f,
1124+
540.3602905273438f, 217.54478454589844f, 451.3944091796875f, 233.6602020263672f, 467.51971435546875f,
1125+
258.30133056640625f, 515.2357788085938f, 280.1400146484375f, 541.3275756835938f, 217.05136108398438f,
1126+
451.8975524902344f, 232.9573974609375f, 466.9907531738281f, 215.86386108398438f, 450.801025390625f,
1127+
232.117919921875f, 466.3701171875f, 279.01593017578125f, 453.6647644042969f, 296.13372802734375f,
1128+
471.4644470214844f, 280.1851806640625f, 454.41900634765625f, 296.481201171875f, 471.63104248046875f,
1129+
259.1214904785156f, 516.8644409179688f, 281.7276306152344f, 541.0162963867188f, 285.2935485839844f,
1130+
389.03515625f, 302.1134948730469f, 406.89373779296875f, 279.6715393066406f, 455.1846923828125f,
1131+
296.6995544433594f, 471.5782470703125f, 258.1405029296875f, 518.9312744140625f, 281.019287109375f,
1132+
541.5760498046875f, 187.80953979492188f, 466.8480224609375f, 208.54336547851562f, 489.9696044921875f}));
1133+
test_case.add_input(
1134+
Shape{1, 1, 50},
1135+
std::vector<float>(
1136+
{5.485373497009277f, 5.469169616699219f, 5.450349807739258f, 5.446445465087891f, 5.43833065032959f,
1137+
5.407294273376465f, 5.3790669441223145f, 5.3575520515441895f, 5.348986625671387f, 5.309826850891113f,
1138+
5.266261577606201f, 5.230800151824951f, 5.079848766326904f, 5.066829204559326f, 4.913329601287842f,
1139+
4.895563125610352f, 4.8786115646362305f, 4.872953414916992f, 4.825906753540039f, 4.812736511230469f,
1140+
4.761179447174072f, 4.657320022583008f, 4.640903949737549f, 4.63286828994751f, 4.600266933441162f,
1141+
4.599870204925537f, 4.5536088943481445f, 4.521742820739746f, 4.465426445007324f, 4.4556074142456055f,
1142+
4.451722621917725f, 4.416017055511475f, 4.410635471343994f, 4.403003215789795f, 4.387508392333984f,
1143+
4.3634934425354f, 4.362300872802734f, 4.348748683929443f, 4.345107555389404f, 4.32416296005249f,
1144+
4.3132781982421875f, 4.287333965301514f, 4.223401069641113f, 4.220005035400391f, 4.179988861083984f,
1145+
4.099865436553955f, 4.097578048706055f, 4.075544357299805f, 4.0459885597229f}));
1146+
1147+
test_case.add_expected_output<int64_t>(Shape{7, 3},
1148+
{0, 0, 0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 10, 0, 0, 11, 0, 0, 22});
1149+
test_case.run();
1150+
}
1151+
10831152
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_log_sum) {
10841153
auto function = onnx_import::import_onnx_model(
10851154
file_util::path_join(CommonTestUtils::getExecutableDirectory(), SERIALIZED_ZOO, "onnx/reduce_log_sum.onnx"));

0 commit comments

Comments
 (0)