Skip to content

Commit 88b792e

Browse files
pavel-esirrkazants
andauthored
[TF FE] Fix centernet and correct FloorDiv translator for signed integer type (openvinotoolkit#22684)
### Details: - Centernet's topk operation returns large int32 values (greater than 1000000), even though they are integer `FloorDiv`/`Div(inp_1, inp_2) + Floor` operation is performed in float16 and because of that it causes accuracy problems. - To solve this need to performs FloorDiv operation in integer with a subgraph: ``` res = x / y; if x > 0 and y > 0 res = x / y - 1; if (x < 0 xor y < 0) and (x mod y != 0) ``` - checked on separate bus: no degradations caused. ### Tickets: - CVS-130526 --------- Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
1 parent b53fa91 commit 88b792e

File tree

4 files changed

+108
-11
lines changed

4 files changed

+108
-11
lines changed

src/frontends/tensorflow_common/src/op/binary_op.cpp

+24-5
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "openvino/op/bitwise_and.hpp"
99
#include "openvino/op/bitwise_or.hpp"
1010
#include "openvino/op/bitwise_xor.hpp"
11+
#include "openvino/op/ceiling.hpp"
1112
#include "openvino/op/concat.hpp"
1213
#include "openvino/op/convert.hpp"
1314
#include "openvino/op/divide.hpp"
@@ -26,9 +27,11 @@
2627
#include "openvino/op/minimum.hpp"
2728
#include "openvino/op/mod.hpp"
2829
#include "openvino/op/multiply.hpp"
30+
#include "openvino/op/negative.hpp"
2931
#include "openvino/op/not_equal.hpp"
3032
#include "openvino/op/power.hpp"
3133
#include "openvino/op/prelu.hpp"
34+
#include "openvino/op/select.hpp"
3235
#include "openvino/op/squared_difference.hpp"
3336
#include "openvino/op/subtract.hpp"
3437
#include "openvino/op/unsqueeze.hpp"
@@ -54,11 +57,27 @@ OutputVector translate_binary_op(const NodeContext& node,
5457
OutputVector translate_floor_div_op(const NodeContext& node) {
5558
auto floordiv_fn = [](const Output<Node>& x, const Output<Node>& y) -> shared_ptr<Node> {
5659
auto out_type = x.get_element_type();
57-
if (out_type.is_integral()) {
58-
auto float_x = make_shared<v0::Convert>(x, element::f32);
59-
auto float_y = make_shared<v0::Convert>(y, element::f32);
60-
return make_shared<v0::Convert>(make_shared<v0::Floor>(make_shared<v1::Divide>(float_x, float_y)),
61-
out_type);
60+
if (out_type.is_integral() && out_type.is_signed()) {
61+
// when integer inputs have different signs remainder should be taken into account
62+
// res = x / y; if x > 0 and y > 0
63+
// res = x / y - 1; if (x < 0 xor y < 0) and (x mod y != 0)
64+
65+
auto zero_const = make_shared<v0::Constant>(out_type, Shape{}, 0);
66+
auto minus_one_const = make_shared<v0::Constant>(out_type, Shape{}, -1);
67+
68+
auto x_less_cond = make_shared<v1::Less>(x, zero_const);
69+
auto y_less_cond = make_shared<v1::Less>(y, zero_const);
70+
auto xor_cond = make_shared<v1::LogicalXor>(x_less_cond, y_less_cond);
71+
72+
auto div = make_shared<v1::Divide>(x, y, false);
73+
auto mod_xy = make_shared<v1::Mod>(x, y);
74+
auto cond_mod = make_shared<v1::NotEqual>(mod_xy, zero_const);
75+
76+
auto cond = make_shared<v1::LogicalAnd>(cond_mod, xor_cond);
77+
auto reminder = make_shared<v1::Select>(cond, minus_one_const, zero_const);
78+
return make_shared<v1::Add>(div, reminder);
79+
} else if (out_type.is_integral() && !out_type.is_signed()) {
80+
return make_shared<v1::Divide>(x, y);
6281
} else {
6382
return make_shared<v0::Floor>(make_shared<v1::Divide>(x, y));
6483
}

tests/layer_tests/common/utils/common_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def generate_ir_python_api(coverage=False, **kwargs):
5353

5454
out_dir = kwargs['output_dir'] + os.sep + kwargs['model_name'] + ".xml"
5555

56-
# TODO: Remove usage of legacy params from layer tests and switch to convert_model from tools.ovc
56+
# TODO: CVS-132151 Remove usage of legacy params from layer tests and switch to convert_model from tools.ovc
5757
ov_model = convert_model(**kwargs)
5858
serialize(ov_model, out_dir)
5959

tests/layer_tests/tensorflow_tests/test_tf_Div.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def create_div_net(self, input_shape, input_type):
3636
dict(input_shape=[10, 20], input_type=np.float32),
3737
dict(input_shape=[2, 3, 4], input_type=np.float32),
3838
pytest.param(dict(input_shape=[8, 5], input_type=np.int32),
39-
marks=pytest.mark.xfail(reason='Ticket TBD - Divide inconsistent behavior on different systems')),
39+
marks=pytest.mark.xfail(reason='Ticket CVS-132377 - Divide inconsistent behavior on different systems')),
4040
dict(input_shape=[], input_type=np.float32),
4141
]
4242

@@ -47,4 +47,4 @@ def test_div_basic(self, params, ie_device, precision, ir_version, temp_dir,
4747
use_new_frontend):
4848
self._test(*self.create_div_net(**params),
4949
ie_device, precision, ir_version, temp_dir=temp_dir,
50-
use_new_frontend=use_new_frontend)
50+
use_new_frontend=use_new_frontend)

tests/layer_tests/tensorflow_tests/test_tf_FloorDiv.py

+81-3
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,26 @@
33

44
import numpy as np
55
import pytest
6+
import platform
67

78
from common.tf_layer_test_class import CommonTFLayerTest
8-
from common.utils.tf_utils import permute_nchw_to_nhwc
99

10+
rng = np.random.default_rng()
11+
12+
def list_arm_platforms():
13+
return ['arm', 'armv7l', 'aarch64', 'arm64', 'ARM64']
1014

1115
class TestFloorDiv(CommonTFLayerTest):
1216
def create_add_placeholder_const_net(self, x_shape, dtype, ir_version, use_new_frontend):
1317
import tensorflow as tf
14-
18+
self.dtype = dtype
1519
tf.compat.v1.reset_default_graph()
1620

1721
# Create the graph and model
1822
with tf.compat.v1.Session() as sess:
1923
x = tf.compat.v1.placeholder(dtype, x_shape, 'Input')
2024
constant_value = np.array(-10).astype(dtype)
2125
y = tf.constant(constant_value)
22-
x = tf.raw_ops.Abs(x=x)
2326
res = tf.raw_ops.FloorDiv(x=x, y=y)
2427

2528
tf.compat.v1.global_variables_initializer()
@@ -29,12 +32,28 @@ def create_add_placeholder_const_net(self, x_shape, dtype, ir_version, use_new_f
2932

3033
return tf_net, ref_net
3134

35+
def _prepare_input(self, inputs_info):
36+
tensor_name = list(inputs_info.keys())[0]
37+
x_shape = inputs_info[tensor_name]
38+
inputs_data = {}
39+
if np.issubdtype(self.dtype, np.floating):
40+
inputs_data[tensor_name] = rng.uniform(-5.0, 5.0, x_shape).astype(self.dtype)
41+
elif np.issubdtype(self.dtype, np.signedinteger):
42+
inputs_data[tensor_name] = rng.integers(-8, 8, x_shape).astype(self.dtype)
43+
else:
44+
inputs_data[tensor_name] = rng.integers(0, 8, x_shape).astype(self.dtype)
45+
return inputs_data
46+
3247
# TODO: implement tests for 2 Consts + Add
3348

49+
3450
test_data_1D = [
3551
dict(x_shape=[], dtype=np.int32),
3652
dict(x_shape=[2], dtype=np.int64),
3753
dict(x_shape=[2, 4, 5], dtype=np.int32),
54+
dict(x_shape=[2, 4, 5], dtype=np.uint32),
55+
dict(x_shape=[2, 4, 5], dtype=np.uint64),
56+
3857
dict(x_shape=[], dtype=np.float32),
3958
dict(x_shape=[2], dtype=np.float64),
4059
dict(x_shape=[2, 4, 5], dtype=np.float32),
@@ -45,7 +64,66 @@ def create_add_placeholder_const_net(self, x_shape, dtype, ir_version, use_new_f
4564
@pytest.mark.precommit_tf_fe
4665
def test_add_placeholder_const_1D(self, params, ie_device, precision, ir_version, temp_dir,
4766
use_new_frontend):
67+
if platform.system() == 'Linux' and platform.machine() in list_arm_platforms() and np.issubdtype(params['dtype'], np.signedinteger):
68+
pytest.xfail(reason='Ticket CVS-132377 - Divide inconsistent behavior on different systems')
69+
4870
self._test(*self.create_add_placeholder_const_net(**params, ir_version=ir_version,
4971
use_new_frontend=use_new_frontend),
5072
ie_device, precision, ir_version, temp_dir=temp_dir,
5173
use_new_frontend=use_new_frontend)
74+
75+
76+
class TestFloorDivStaticInput(CommonTFLayerTest):
77+
min = -100
78+
max = 200
79+
step = 1
80+
dtype = np.int32
81+
82+
def create_flordiv_tf_net(self, min, max, step, y, dtype, ir_version, use_new_frontend):
83+
import tensorflow as tf
84+
x = np.arange(min, max, step, dtype=dtype)
85+
86+
self.min = min
87+
self.max = max
88+
self.step = step
89+
self.dtype = dtype
90+
91+
tf.compat.v1.reset_default_graph()
92+
93+
with tf.compat.v1.Session() as sess:
94+
x = tf.compat.v1.placeholder(dtype, x.shape, 'Input')
95+
y = tf.constant(np.array(y).astype(dtype))
96+
res = tf.raw_ops.FloorDiv(x=x, y=y)
97+
98+
tf.compat.v1.global_variables_initializer()
99+
tf_net = sess.graph_def
100+
101+
ref_net = None
102+
103+
return tf_net, ref_net
104+
105+
def _prepare_input(self, inputs_dict):
106+
for input in inputs_dict.keys():
107+
inputs_dict[input] = np.arange(self.min, self.max, self.step, dtype=self.dtype)
108+
return inputs_dict
109+
110+
test_inputs = [
111+
dict(min=-20, max=20, step=1, y=[10]),
112+
dict(min=-20, max=20, step=1, y=[5]),
113+
dict(min=-20, max=20, step=1, y=[6]),
114+
dict(min=-20, max=20, step=1, y=[-5]),
115+
dict(min=-20, max=20, step=1, y=[-6]),
116+
dict(min=-1e5, max=1e5, step=100, y=[1e5]),
117+
]
118+
@pytest.mark.parametrize("params", test_inputs)
119+
@pytest.mark.parametrize("dtype", [np.int32, np.int64])
120+
@pytest.mark.nightly
121+
@pytest.mark.precommit_tf_fe
122+
@pytest.mark.xfail(condition=platform.system() == 'Linux' and platform.machine() in list_arm_platforms(),
123+
reason='Ticket CVS-132377 - Divide inconsistent behavior on different systems')
124+
def test_floordiv(self, params, dtype, ie_device, precision, ir_version, temp_dir,
125+
use_new_frontend):
126+
self._test(*self.create_flordiv_tf_net(**params, dtype=dtype, ir_version=ir_version,
127+
use_new_frontend=use_new_frontend),
128+
ie_device, precision, ir_version, temp_dir=temp_dir,
129+
use_new_frontend=use_new_frontend)

0 commit comments

Comments
 (0)