Skip to content

Commit fa43c7b

Browse files
authored
[TF FE] Add layer test for tf.raw_ops.StaticRegexReplace operation (openvinotoolkit#22973)
**Details:** Add layer test for tf.raw_ops.StaticRegexReplace operation **Ticket:** 132910 Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
1 parent fd78fb2 commit fa43c7b

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright (C) 2018-2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import numpy as np
5+
import pytest
6+
import tensorflow as tf
7+
from common.tf_layer_test_class import CommonTFLayerTest
8+
9+
rng = np.random.default_rng()
10+
11+
12+
class TestStaticRegexReplace(CommonTFLayerTest):
13+
def _prepare_input(self, inputs_info):
14+
assert 'input' in inputs_info
15+
input_shape = inputs_info['input']
16+
inputs_data = {}
17+
strings_dictionary = ['UPPER CASE SENTENCE', 'lower case sentence', ' UppEr LoweR CAse SENtence \t\n', ' ',
18+
'Oferta polska', 'Предложение по-РУССки', '汉语句子']
19+
inputs_data['input'] = rng.choice(strings_dictionary, input_shape)
20+
return inputs_data
21+
22+
def create_static_regex_replace_net(self, input_shape, pattern, rewrite, replace_global):
23+
self.pattern = pattern
24+
25+
tf.compat.v1.reset_default_graph()
26+
with tf.compat.v1.Session() as sess:
27+
input = tf.compat.v1.placeholder(tf.string, input_shape, 'input')
28+
tf.raw_ops.StaticRegexReplace(input=input, pattern=pattern, rewrite=rewrite, replace_global=replace_global)
29+
tf.compat.v1.global_variables_initializer()
30+
tf_net = sess.graph_def
31+
32+
ref_net = None
33+
34+
return tf_net, ref_net
35+
36+
@pytest.mark.parametrize('input_shape', [[], [2], [3, 4], [1, 3, 2]])
37+
@pytest.mark.parametrize('pattern', ['(\s)|(-)', '[A-Z]{2,}', '^\s+|\s+$'])
38+
@pytest.mark.parametrize('rewrite', ['', 'replacement word'])
39+
@pytest.mark.parametrize('replace_global', [None, True, False])
40+
@pytest.mark.precommit_tf_fe
41+
@pytest.mark.nightly
42+
@pytest.mark.xfail(reason='132674 - Add support of StaticRegexReplace')
43+
def test_static_regex_replace(self, input_shape, pattern, rewrite, replace_global,
44+
ie_device, precision, ir_version, temp_dir,
45+
use_legacy_frontend):
46+
self._test(*self.create_static_regex_replace_net(input_shape=input_shape, pattern=pattern, rewrite=rewrite,
47+
replace_global=replace_global),
48+
ie_device, precision, ir_version, temp_dir=temp_dir,
49+
use_legacy_frontend=use_legacy_frontend)

0 commit comments

Comments
 (0)