|
| 1 | +# Copyright (C) 2018-2024 Intel Corporation |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +import pytest |
| 6 | +import tensorflow as tf |
| 7 | +from common.tf_layer_test_class import CommonTFLayerTest |
| 8 | + |
| 9 | +rng = np.random.default_rng(872173) |
| 10 | + |
| 11 | + |
| 12 | +class TestTensorScatterAdd(CommonTFLayerTest): |
| 13 | + def _prepare_input(self, inputs_info): |
| 14 | + assert 'tensor:0' in inputs_info |
| 15 | + assert 'indices:0' in inputs_info |
| 16 | + assert 'updates:0' in inputs_info |
| 17 | + |
| 18 | + tensor_shape = inputs_info['tensor:0'] |
| 19 | + updates_shape = inputs_info['updates:0'] |
| 20 | + indices_shape = inputs_info['indices:0'] |
| 21 | + |
| 22 | + inputs_data = {} |
| 23 | + if np.issubdtype(self.data_type, np.floating): |
| 24 | + inputs_data['tensor:0'] = rng.uniform(-5.0, 5.0, tensor_shape).astype(self.data_type) |
| 25 | + inputs_data['updates:0'] = rng.uniform(-5.0, 5.0, updates_shape).astype(self.data_type) |
| 26 | + elif np.issubdtype(self.data_type, np.signedinteger): |
| 27 | + inputs_data['tensor:0'] = rng.integers(-8, 8, tensor_shape).astype(self.data_type) |
| 28 | + inputs_data['updates:0'] = rng.integers(-8, 8, updates_shape).astype(self.data_type) |
| 29 | + else: |
| 30 | + inputs_data['tensor:0'] = rng.integers(0, 8, tensor_shape).astype(self.data_type) |
| 31 | + inputs_data['updates:0'] = rng.integers(0, 8, updates_shape).astype(self.data_type) |
| 32 | + |
| 33 | + indices_rows, indices_col = indices_shape |
| 34 | + |
| 35 | + indices_of_tensor_shape = [] |
| 36 | + for i in range(0, indices_col): |
| 37 | + indices_of_tensor_shape.append(np.arange(tensor_shape[i])) |
| 38 | + |
| 39 | + mesh = np.meshgrid(*indices_of_tensor_shape) |
| 40 | + |
| 41 | + all_indicies = np.stack(mesh, axis=indices_col) |
| 42 | + all_indicies = all_indicies.reshape(-1, all_indicies.shape[-1]) |
| 43 | + |
| 44 | + inputs_data['indices:0'] = rng.choice(all_indicies, indices_rows, replace=False).astype(self.indices_type) |
| 45 | + |
| 46 | + return inputs_data |
| 47 | + |
| 48 | + def create_tensor_scatter_add_net(self, data_type, indices_type, tensor_shape, updates_shape, indices_shape): |
| 49 | + self.data_type = data_type |
| 50 | + self.indices_type = indices_type |
| 51 | + self.tensor_shape = tensor_shape |
| 52 | + self.updates_shape = updates_shape |
| 53 | + self.indices_shape = indices_shape |
| 54 | + tf.compat.v1.reset_default_graph() |
| 55 | + with tf.compat.v1.Session() as sess: |
| 56 | + indices = tf.compat.v1.placeholder(indices_type, indices_shape, 'indices') |
| 57 | + tensor = tf.compat.v1.placeholder(data_type, tensor_shape, 'tensor') |
| 58 | + updates = tf.compat.v1.placeholder(data_type, updates_shape, 'updates') |
| 59 | + tf.raw_ops.TensorScatterAdd( |
| 60 | + tensor=tensor, |
| 61 | + indices=indices, |
| 62 | + updates=updates) |
| 63 | + tf.compat.v1.global_variables_initializer() |
| 64 | + tf_net = sess.graph_def |
| 65 | + |
| 66 | + ref_net = None |
| 67 | + |
| 68 | + return tf_net, ref_net |
| 69 | + |
| 70 | + @pytest.mark.parametrize('data_type', [np.float32, np.float64, np.int32]) |
| 71 | + @pytest.mark.parametrize('indices_type', [np.int32, np.int64]) |
| 72 | + @pytest.mark.parametrize('tensor_shape, updates_shape, indices_shape', [ |
| 73 | + [[10, 5], [2], [2, 2]], |
| 74 | + [[4, 4, 4], [2, 4, 4], [2, 1]], |
| 75 | + [[2, 4, 8], [3], [3, 3]], |
| 76 | + [[4, 3, 5], [1, 5], [1, 2]], |
| 77 | + ]) |
| 78 | + @pytest.mark.precommit |
| 79 | + @pytest.mark.nightly |
| 80 | + def test_tensor_scatter_add(self, data_type, indices_type, |
| 81 | + tensor_shape, updates_shape, indices_shape, |
| 82 | + ie_device, precision, ir_version, temp_dir, |
| 83 | + use_legacy_frontend): |
| 84 | + if ie_device == 'GPU': |
| 85 | + pytest.skip("160549: ScatterNDUpdate(opset15) is not supported on GPU") |
| 86 | + self._test(*self.create_tensor_scatter_add_net(data_type, indices_type, |
| 87 | + tensor_shape, updates_shape, indices_shape), |
| 88 | + ie_device, precision, ir_version, temp_dir=temp_dir, |
| 89 | + use_legacy_frontend=use_legacy_frontend) |
0 commit comments