Skip to content

Commit 19e7dae

Browse files
authored
[TF FE] Stabilize tests for unary operation defined on full real domain (openvinotoolkit#28111)
**Details:** Stabilize tests for unary operation defined on full real domain **Ticket:** TBD --------- Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
1 parent 706f340 commit 19e7dae

File tree

2 files changed

+78
-43
lines changed

2 files changed

+78
-43
lines changed

tests/layer_tests/tensorflow_tests/test_tf_UnaryOps.py

+8-43
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
# Copyright (C) 2018-2024 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
33

4-
import platform
5-
import sys
6-
74
import numpy as np
5+
import platform
86
import pytest
7+
import sys
98
from common.tf_layer_test_class import CommonTFLayerTest
109

1110

@@ -14,7 +13,7 @@ class TestUnaryOps(CommonTFLayerTest):
1413

1514
def _prepare_input(self, inputs_dict):
1615
non_negative = ['Sqrt', 'Log']
17-
narrow_borders = ["Sinh", "Cosh", "Tanh", "Exp", "Selu"]
16+
narrow_borders = ["Tanh"]
1817
within_one = ['Asin', 'Acos', 'Atanh']
1918
from_one = ['Acosh']
2019

@@ -76,25 +75,14 @@ def create_net_with_unary_op(self, shape, ir_version, op_type, use_legacy_fronte
7675
'Asin': tf.math.asin,
7776
'Asinh': tf.math.asinh,
7877
'Atan': tf.math.atan,
79-
'Atanh': tf.math.atanh,
8078
'BitwiseNot': tf.bitwise.invert,
8179
'Ceiling': tf.math.ceil,
82-
'Cos': tf.math.cos,
83-
'Cosh': tf.math.cosh,
84-
'Elu': tf.nn.elu,
85-
'Erf': tf.math.erf,
86-
'Exp': tf.math.exp,
8780
'Floor': tf.math.floor,
8881
'Log': tf.math.log,
8982
'LogicalNot': tf.math.logical_not,
9083
# 'Mish': tfa.activations.mish, # temporarily moved to `create_net_with_mish()`
9184
'Negative': tf.math.negative,
92-
'Selu': tf.nn.selu,
93-
'Sigmoid': tf.nn.sigmoid,
9485
'Sign': tf.math.sign,
95-
'Sin': tf.math.sin,
96-
'Sinh': tf.math.sinh,
97-
'SoftPlus': tf.nn.softplus,
9886
'Square': tf.math.square,
9987
'Tan': tf.math.tan,
10088
'Tanh': tf.math.tanh,
@@ -126,15 +114,8 @@ def create_net_with_unary_op(self, shape, ir_version, op_type, use_legacy_fronte
126114
test_data_precommit = [dict(shape=[4, 6, 8, 10, 12])]
127115

128116
@pytest.mark.parametrize("params", test_data_precommit)
129-
@pytest.mark.parametrize("op_type", ['Elu',
130-
'Sigmoid',
131-
'Sin',
132-
'Sinh',
133-
'Cos',
134-
'Cosh',
135-
'Abs',
117+
@pytest.mark.parametrize("op_type", ['Abs',
136118
'Negative',
137-
'Exp',
138119
'Tan',
139120
'Tanh',
140121
'Floor',
@@ -145,15 +126,11 @@ def create_net_with_unary_op(self, shape, ir_version, op_type, use_legacy_fronte
145126
'Atan',
146127
'Log',
147128
'Sign',
148-
'SoftPlus',
149-
'Atanh',
150129
'Acosh',
151130
'Asinh',
152131
'LogicalNot',
153132
'Square',
154-
'Erf',
155-
'BitwiseNot'
156-
])
133+
'BitwiseNot'])
157134
@pytest.mark.nightly
158135
def test_unary_op_precommit(self, params, ie_device, precision, ir_version, temp_dir, op_type,
159136
use_legacy_frontend):
@@ -188,15 +165,8 @@ def test_unary_op_mish_precommit(self, params, ie_device, precision, ir_version,
188165
dict(shape=[4, 6, 8, 10, 12])]
189166

190167
@pytest.mark.parametrize("params", test_data)
191-
@pytest.mark.parametrize("op_type", ['Elu',
192-
'Sigmoid',
193-
'Sin',
194-
'Sinh',
195-
'Cos',
196-
'Cosh',
197-
'Abs',
168+
@pytest.mark.parametrize("op_type", ['Abs',
198169
'Negative',
199-
'Exp',
200170
'Tan',
201171
'Tanh',
202172
'Floor',
@@ -206,17 +176,12 @@ def test_unary_op_mish_precommit(self, params, ie_device, precision, ir_version,
206176
'Acos',
207177
'Atan',
208178
'Log',
209-
'LogicalNot',
210179
'Sign',
211-
'SoftPlus',
212-
'Atanh',
213180
'Acosh',
214181
'Asinh',
182+
'LogicalNot',
215183
'Square',
216-
'Erf',
217-
'Selu',
218-
'BitwiseNot'
219-
])
184+
'BitwiseNot'])
220185
@pytest.mark.nightly
221186
@pytest.mark.skipif(sys.platform == 'darwin', reason="Ticket - 122182")
222187
@pytest.mark.xfail(platform.machine() in ["aarch64", "arm64", "ARM64"], reason='Ticket - 122716')
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright (C) 2018-2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import numpy as np
5+
import platform
6+
import pytest
7+
import tensorflow as tf
8+
from common.tf_layer_test_class import CommonTFLayerTest
9+
10+
rng = np.random.default_rng(253512)
11+
12+
13+
class TestUnaryOpsAllRealDomain(CommonTFLayerTest):
14+
def _prepare_input(self, inputs_info):
15+
assert 'x:0' in inputs_info, "Test error: inputs_info must contain `x`"
16+
x_shape = inputs_info['x:0']
17+
inputs_data = {}
18+
inputs_data['x:0'] = rng.uniform(-5.0, 5.0, x_shape).astype(self.input_type)
19+
return inputs_data
20+
21+
def create_unary_net(self, input_shape, input_type, op_type):
22+
op_type_map = {
23+
'Elu': lambda x: tf.raw_ops.Elu(features=x),
24+
'Sigmoid': tf.raw_ops.Sigmoid,
25+
'Sin': tf.raw_ops.Sin,
26+
'Sinh': tf.raw_ops.Sinh,
27+
'Cos': tf.raw_ops.Cos,
28+
'Cosh': tf.raw_ops.Cosh,
29+
'Exp': tf.raw_ops.Exp,
30+
'Atan': tf.raw_ops.Atan,
31+
'Softplus': lambda x: tf.raw_ops.Softplus(features=x),
32+
'Erf': tf.raw_ops.Erf,
33+
'Selu': lambda x: tf.raw_ops.Selu(features=x)
34+
}
35+
36+
self.input_type = input_type
37+
tf.compat.v1.reset_default_graph()
38+
# Create the graph and model
39+
with tf.compat.v1.Session() as sess:
40+
x = tf.compat.v1.placeholder(input_type, input_shape, 'x')
41+
op_type_map[op_type](x=x)
42+
tf.compat.v1.global_variables_initializer()
43+
44+
tf_net = sess.graph_def
45+
46+
return tf_net, None
47+
48+
@pytest.mark.parametrize("input_shape", [[], [2], [3, 4], [3, 2, 4]])
49+
@pytest.mark.parametrize("input_type", [np.float16, np.float32, np.float64])
50+
@pytest.mark.parametrize("op_type", ['Elu',
51+
'Sigmoid',
52+
'Sin',
53+
'Sinh',
54+
'Cos',
55+
'Cosh',
56+
'Exp',
57+
'Atan',
58+
'Softplus',
59+
'Erf',
60+
'Selu'])
61+
@pytest.mark.precommit
62+
@pytest.mark.nightly
63+
def test_unary_ops(self, input_shape, input_type, op_type,
64+
ie_device, precision, ir_version, temp_dir,
65+
use_legacy_frontend):
66+
if platform.machine() in ["aarch64", "arm64", "ARM64"] and op_type in ['Cos', 'Cosh', 'Sinh', 'Exp']:
67+
pytest.skip("159585: accuracy error on ARM")
68+
self._test(*self.create_unary_net(input_shape, input_type, op_type),
69+
ie_device, precision, ir_version, temp_dir=temp_dir,
70+
use_legacy_frontend=use_legacy_frontend, custom_eps=1e-3)

0 commit comments

Comments
 (0)