Skip to content

Commit 842fedc

Browse files
sumhajrkazants
andauthored
[TF FE] Add support for TensorScatterAdd in TF FE (openvinotoolkit#28419)
**Overview**: This pull request fixes openvinotoolkit#25050 All testcases passed Continuation of PR openvinotoolkit#26481 **Dependencies**: - No dependencies on other pull requests. **CC**: @rkazants, @mlukasze --------- Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
1 parent e390175 commit 842fedc

File tree

5 files changed

+121
-1
lines changed

5 files changed

+121
-1
lines changed

src/frontends/tensorflow/docs/supported_ops.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1314,7 +1314,7 @@ A "supported operation" is one that TensorFlow Frontend can convert to the OpenV
13141314
| TensorListSetItem | YES | |
13151315
| TensorListSplit | NO | |
13161316
| TensorListStack | YES | |
1317-
| TensorScatterAdd | NO | |
1317+
| TensorScatterAdd | YES | |
13181318
| TensorScatterMax | NO | |
13191319
| TensorScatterMin | NO | |
13201320
| TensorScatterSub | NO | |

src/frontends/tensorflow/src/op_table.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
414414
{"TensorListReserve", CreatorFunction(translate_tensor_list_reserve_op)},
415415
{"TensorListResize", CreatorFunction(translate_tensor_list_resize_op)},
416416
{"TensorListConcatV2", CreatorFunction(translate_tensor_list_concat_v2_op)},
417+
{"TensorScatterAdd", CreatorFunction(translate_tensor_scatter_add_op)},
417418
{"TensorScatterUpdate", CreatorFunction(translate_tensor_scatter_update_op)},
418419
{"Tile", CreatorFunction(translate_tile_op)},
419420
{"ToBool", CreatorFunction(translate_tobool_op)},

src/frontends/tensorflow_common/include/common_op_table.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ OP_CONVERTER(translate_tensor_list_set_item_op);
177177
OP_CONVERTER(translate_tensor_list_stack_op);
178178
OP_CONVERTER(translate_tensor_list_resize_op);
179179
OP_CONVERTER(translate_tensor_list_concat_v2_op);
180+
OP_CONVERTER(translate_tensor_scatter_add_op);
180181
OP_CONVERTER(translate_tensor_scatter_update_op);
181182
OP_CONVERTER(translate_tile_op);
182183
OP_CONVERTER(translate_tobool_op);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// Copyright (C) 2018-2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "common_op_table.hpp"
6+
#include "openvino/op/scatter_nd_update.hpp"
7+
8+
using namespace std;
9+
using namespace ov::op;
10+
11+
namespace ov {
12+
namespace frontend {
13+
namespace tensorflow {
14+
namespace op {
15+
OutputVector translate_tensor_scatter_add_op(const NodeContext& node) {
16+
default_op_checks(node, 3, {"TensorScatterAdd"});
17+
auto data = node.get_input(0);
18+
auto indices = node.get_input(1);
19+
auto updates = node.get_input(2);
20+
auto reduction = v15::ScatterNDUpdate::Reduction::SUM;
21+
auto scatter_add_op = make_shared<v15::ScatterNDUpdate>(data, indices, updates, reduction);
22+
set_node_name(node.get_name(), scatter_add_op);
23+
24+
return {scatter_add_op};
25+
}
26+
} // namespace op
27+
} // namespace tensorflow
28+
} // namespace frontend
29+
} // namespace ov
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright (C) 2018-2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import numpy as np
5+
import pytest
6+
import tensorflow as tf
7+
from common.tf_layer_test_class import CommonTFLayerTest
8+
9+
rng = np.random.default_rng(872173)
10+
11+
12+
class TestTensorScatterAdd(CommonTFLayerTest):
13+
def _prepare_input(self, inputs_info):
14+
assert 'tensor:0' in inputs_info
15+
assert 'indices:0' in inputs_info
16+
assert 'updates:0' in inputs_info
17+
18+
tensor_shape = inputs_info['tensor:0']
19+
updates_shape = inputs_info['updates:0']
20+
indices_shape = inputs_info['indices:0']
21+
22+
inputs_data = {}
23+
if np.issubdtype(self.data_type, np.floating):
24+
inputs_data['tensor:0'] = rng.uniform(-5.0, 5.0, tensor_shape).astype(self.data_type)
25+
inputs_data['updates:0'] = rng.uniform(-5.0, 5.0, updates_shape).astype(self.data_type)
26+
elif np.issubdtype(self.data_type, np.signedinteger):
27+
inputs_data['tensor:0'] = rng.integers(-8, 8, tensor_shape).astype(self.data_type)
28+
inputs_data['updates:0'] = rng.integers(-8, 8, updates_shape).astype(self.data_type)
29+
else:
30+
inputs_data['tensor:0'] = rng.integers(0, 8, tensor_shape).astype(self.data_type)
31+
inputs_data['updates:0'] = rng.integers(0, 8, updates_shape).astype(self.data_type)
32+
33+
indices_rows, indices_col = indices_shape
34+
35+
indices_of_tensor_shape = []
36+
for i in range(0, indices_col):
37+
indices_of_tensor_shape.append(np.arange(tensor_shape[i]))
38+
39+
mesh = np.meshgrid(*indices_of_tensor_shape)
40+
41+
all_indicies = np.stack(mesh, axis=indices_col)
42+
all_indicies = all_indicies.reshape(-1, all_indicies.shape[-1])
43+
44+
inputs_data['indices:0'] = rng.choice(all_indicies, indices_rows, replace=False).astype(self.indices_type)
45+
46+
return inputs_data
47+
48+
def create_tensor_scatter_add_net(self, data_type, indices_type, tensor_shape, updates_shape, indices_shape):
49+
self.data_type = data_type
50+
self.indices_type = indices_type
51+
self.tensor_shape = tensor_shape
52+
self.updates_shape = updates_shape
53+
self.indices_shape = indices_shape
54+
tf.compat.v1.reset_default_graph()
55+
with tf.compat.v1.Session() as sess:
56+
indices = tf.compat.v1.placeholder(indices_type, indices_shape, 'indices')
57+
tensor = tf.compat.v1.placeholder(data_type, tensor_shape, 'tensor')
58+
updates = tf.compat.v1.placeholder(data_type, updates_shape, 'updates')
59+
tf.raw_ops.TensorScatterAdd(
60+
tensor=tensor,
61+
indices=indices,
62+
updates=updates)
63+
tf.compat.v1.global_variables_initializer()
64+
tf_net = sess.graph_def
65+
66+
ref_net = None
67+
68+
return tf_net, ref_net
69+
70+
@pytest.mark.parametrize('data_type', [np.float32, np.float64, np.int32])
71+
@pytest.mark.parametrize('indices_type', [np.int32, np.int64])
72+
@pytest.mark.parametrize('tensor_shape, updates_shape, indices_shape', [
73+
[[10, 5], [2], [2, 2]],
74+
[[4, 4, 4], [2, 4, 4], [2, 1]],
75+
[[2, 4, 8], [3], [3, 3]],
76+
[[4, 3, 5], [1, 5], [1, 2]],
77+
])
78+
@pytest.mark.precommit
79+
@pytest.mark.nightly
80+
def test_tensor_scatter_add(self, data_type, indices_type,
81+
tensor_shape, updates_shape, indices_shape,
82+
ie_device, precision, ir_version, temp_dir,
83+
use_legacy_frontend):
84+
if ie_device == 'GPU':
85+
pytest.skip("160549: ScatterNDUpdate(opset15) is not supported on GPU")
86+
self._test(*self.create_tensor_scatter_add_net(data_type, indices_type,
87+
tensor_shape, updates_shape, indices_shape),
88+
ie_device, precision, ir_version, temp_dir=temp_dir,
89+
use_legacy_frontend=use_legacy_frontend)

0 commit comments

Comments
 (0)