|
| 1 | +# Copyright (C) 2018-2024 Intel Corporation |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +import pytest |
| 6 | +import tensorflow as tf |
| 7 | +from common.tf_layer_test_class import CommonTFLayerTest |
| 8 | + |
| 9 | +rng = np.random.default_rng() |
| 10 | + |
| 11 | + |
| 12 | +class TestStaticRegexReplace(CommonTFLayerTest): |
| 13 | + def _prepare_input(self, inputs_info): |
| 14 | + assert 'input' in inputs_info |
| 15 | + input_shape = inputs_info['input'] |
| 16 | + inputs_data = {} |
| 17 | + strings_dictionary = ['UPPER CASE SENTENCE', 'lower case sentence', ' UppEr LoweR CAse SENtence \t\n', ' ', |
| 18 | + 'Oferta polska', 'Предложение по-РУССки', '汉语句子'] |
| 19 | + inputs_data['input'] = rng.choice(strings_dictionary, input_shape) |
| 20 | + return inputs_data |
| 21 | + |
| 22 | + def create_static_regex_replace_net(self, input_shape, pattern, rewrite, replace_global): |
| 23 | + self.pattern = pattern |
| 24 | + |
| 25 | + tf.compat.v1.reset_default_graph() |
| 26 | + with tf.compat.v1.Session() as sess: |
| 27 | + input = tf.compat.v1.placeholder(tf.string, input_shape, 'input') |
| 28 | + tf.raw_ops.StaticRegexReplace(input=input, pattern=pattern, rewrite=rewrite, replace_global=replace_global) |
| 29 | + tf.compat.v1.global_variables_initializer() |
| 30 | + tf_net = sess.graph_def |
| 31 | + |
| 32 | + ref_net = None |
| 33 | + |
| 34 | + return tf_net, ref_net |
| 35 | + |
| 36 | + @pytest.mark.parametrize('input_shape', [[], [2], [3, 4], [1, 3, 2]]) |
| 37 | + @pytest.mark.parametrize('pattern', ['(\s)|(-)', '[A-Z]{2,}', '^\s+|\s+$']) |
| 38 | + @pytest.mark.parametrize('rewrite', ['', 'replacement word']) |
| 39 | + @pytest.mark.parametrize('replace_global', [None, True, False]) |
| 40 | + @pytest.mark.precommit_tf_fe |
| 41 | + @pytest.mark.nightly |
| 42 | + @pytest.mark.xfail(reason='132674 - Add support of StaticRegexReplace') |
| 43 | + def test_static_regex_replace(self, input_shape, pattern, rewrite, replace_global, |
| 44 | + ie_device, precision, ir_version, temp_dir, |
| 45 | + use_legacy_frontend): |
| 46 | + self._test(*self.create_static_regex_replace_net(input_shape=input_shape, pattern=pattern, rewrite=rewrite, |
| 47 | + replace_global=replace_global), |
| 48 | + ie_device, precision, ir_version, temp_dir=temp_dir, |
| 49 | + use_legacy_frontend=use_legacy_frontend) |
0 commit comments