Skip to content

Commit b409ea1

Browse files
authored
[TF FE] Support TF1 While Control flow (openvinotoolkit#20105)
* [TF FE] Support TF1 While Control flow Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Apply code-style fix * Update API for OpPlace to store back edge * Fix build: no rvalue by reference passing * Fix build issue: correct type * Fix TF FE unit-tests * Apply code-review feedback: remove unused vars * Fix fusing complicated case of TF1 While * Remove unused variable * Update MO unit test * Fix layer tests for While * Handle Switch and NextIteration nodes connected directly Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> --------- Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
1 parent e2501a6 commit b409ea1

23 files changed

+694
-168
lines changed

src/frontends/tensorflow/src/input_model.cpp

+29-3
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,10 @@ void InputModel::InputModelTFImpl::load_places() {
137137
auto op_name = node_decoder->get_op_name();
138138
auto op_type = node_decoder->get_op_type();
139139

140+
if (op_type == "Placeholder" && op_name.rfind("unused_control_flow_input", 0) != std::string::npos) {
141+
continue;
142+
}
143+
140144
if (m_telemetry) {
141145
op_statistics[op_type]++;
142146
}
@@ -320,9 +324,6 @@ std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::topologicall
320324
std::stack<std::shared_ptr<OpPlace>> ops_to_do;
321325
std::unordered_set<std::shared_ptr<OpPlace>> ops_done;
322326

323-
// TODO: implement logic to check direct cycles in the graph
324-
// and break them
325-
// probably not only NextIteration can generate cycles
326327
for (const auto& output_place : m_outputs) {
327328
FRONT_END_GENERAL_CHECK(output_place->get_names().size() > 0, "TensorPlace must have at least one name.");
328329
auto output_place_name = output_place->get_names()[0];
@@ -336,6 +337,23 @@ std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::topologicall
336337
ops_to_do.push(output_operation_place);
337338
}
338339

340+
// walk through all NextIteration nodes and put their producers into ops_to_do
341+
// this is needed to avoid missed nodes in the body graph of TF1 While operation
342+
for (const auto& op_place : m_op_places) {
343+
auto op_decoder = op_place->get_decoder();
344+
if (op_decoder->get_op_type() == "NextIteration") {
345+
std::string producer_name;
346+
std::string producer_output_port_name;
347+
size_t producer_output_port_idx;
348+
op_decoder->get_input_node(0, producer_name, producer_output_port_name, producer_output_port_idx);
349+
FRONT_END_GENERAL_CHECK(m_op_places_map.count(producer_name),
350+
"[TensorFlow Frontend] internal error or inconsistent model: producer of "
351+
"NextIteration is not found among operation places " +
352+
producer_name);
353+
ops_to_do.push(m_op_places_map.at(producer_name));
354+
}
355+
}
356+
339357
// the traversing algorithm to compute topologically sorted nodes is taken from topological_sort in
340358
// core/graph_util.hpp
341359
while (ops_to_do.size() > 0) {
@@ -350,6 +368,14 @@ std::vector<std::shared_ptr<OpPlace>> InputModel::InputModelTFImpl::topologicall
350368
if (current_operation_type == "NextIteration") {
351369
// break the cycle created by NextIteration
352370
input_count = 0;
371+
std::string producer_name;
372+
std::string producer_output_port_name;
373+
size_t producer_output_port_idx;
374+
current_operation_decoder->get_input_node(0,
375+
producer_name,
376+
producer_output_port_name,
377+
producer_output_port_idx);
378+
current_operation_place->set_next_iteration_back_edge(producer_name, producer_output_port_idx);
353379
}
354380

355381
for (size_t input_port_idx = 0; input_port_idx < input_count; ++input_port_idx) {
+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright (C) 2018-2023 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "helper_ops/enter.hpp"
6+
7+
#include "common_op_table.hpp"
8+
#include "openvino/frontend/tensorflow/node_context.hpp"
9+
#include "utils.hpp"
10+
11+
using namespace std;
12+
using namespace ov;
13+
using namespace ov::frontend::tensorflow;
14+
15+
namespace ov {
16+
namespace frontend {
17+
namespace tensorflow {
18+
namespace op {
19+
20+
OutputVector translate_enter_op(const NodeContext& node) {
21+
default_op_checks(node, 1, {"Enter"});
22+
auto data = node.get_input(0);
23+
auto frame_name = node.get_attribute<string>("frame_name");
24+
25+
auto enter_node = make_shared<Enter>(data, frame_name, node.get_decoder());
26+
set_node_name(node.get_name(), enter_node);
27+
28+
return enter_node->outputs();
29+
}
30+
31+
} // namespace op
32+
} // namespace tensorflow
33+
} // namespace frontend
34+
} // namespace ov
+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright (C) 2018-2023 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "helper_ops/exit.hpp"
6+
7+
#include "common_op_table.hpp"
8+
#include "openvino/frontend/tensorflow/node_context.hpp"
9+
#include "utils.hpp"
10+
11+
using namespace std;
12+
using namespace ov;
13+
using namespace ov::frontend::tensorflow;
14+
15+
namespace ov {
16+
namespace frontend {
17+
namespace tensorflow {
18+
namespace op {
19+
20+
OutputVector translate_exit_op(const NodeContext& node) {
21+
default_op_checks(node, 1, {"Exit"});
22+
auto data = node.get_input(0);
23+
24+
auto exit_node = make_shared<Exit>(data, node.get_decoder());
25+
set_node_name(node.get_name(), exit_node);
26+
27+
return exit_node->outputs();
28+
}
29+
30+
} // namespace op
31+
} // namespace tensorflow
32+
} // namespace frontend
33+
} // namespace ov
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright (C) 2018-2023 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "helper_ops/loop_cond.hpp"
6+
7+
#include "common_op_table.hpp"
8+
#include "openvino/frontend/tensorflow/node_context.hpp"
9+
#include "utils.hpp"
10+
11+
using namespace std;
12+
using namespace ov;
13+
using namespace ov::op;
14+
using namespace ov::frontend::tensorflow;
15+
16+
namespace ov {
17+
namespace frontend {
18+
namespace tensorflow {
19+
namespace op {
20+
21+
OutputVector translate_loop_cond_op(const NodeContext& node) {
22+
default_op_checks(node, 1, {"LoopCond"});
23+
auto input = node.get_input(0);
24+
25+
auto loop_cond_node = make_shared<LoopCond>(input, node.get_decoder());
26+
set_node_name(node.get_name(), loop_cond_node);
27+
28+
return loop_cond_node->outputs();
29+
}
30+
31+
} // namespace op
32+
} // namespace tensorflow
33+
} // namespace frontend
34+
} // namespace ov
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright (C) 2018-2023 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "helper_ops/next_iteration.hpp"
6+
7+
#include "common_op_table.hpp"
8+
#include "helper_ops/merge.hpp"
9+
#include "openvino/frontend/tensorflow/node_context.hpp"
10+
#include "utils.hpp"
11+
12+
using namespace std;
13+
using namespace ov;
14+
using namespace ov::frontend::tensorflow;
15+
16+
namespace ov {
17+
namespace frontend {
18+
namespace tensorflow {
19+
namespace op {
20+
21+
OutputVector translate_next_iteration_op(const NodeContext& node) {
22+
default_op_checks(node, 0, {"NextIteration"});
23+
24+
auto next_iteration_node = make_shared<NextIteration>(node.get_decoder());
25+
set_node_name(node.get_name(), next_iteration_node);
26+
27+
return next_iteration_node->outputs();
28+
}
29+
30+
} // namespace op
31+
} // namespace tensorflow
32+
} // namespace frontend
33+
} // namespace ov

src/frontends/tensorflow/src/op/partitioned_call.cpp

+6-7
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
#include "common_op_table.hpp"
66
#include "input_model.hpp"
7-
#include "openvino/opsets/opset10.hpp"
7+
#include "tf_utils.hpp"
88

99
using namespace std;
1010
using namespace ov;
@@ -18,7 +18,7 @@ OutputVector translate_partitioned_call_op(const NodeContext& node) {
1818
auto node_name = node.get_name();
1919
auto translate_session = node.get_translate_session();
2020
FRONT_END_GENERAL_CHECK(translate_session, "[TensorFlow Frontend] Internal error: Translate session is nullptr.");
21-
auto operation_type = node.get_attribute<std::string>("f");
21+
auto operation_type = node.get_attribute<string>("f");
2222

2323
// prepare a vector of inputs
2424
OutputVector ov_inputs;
@@ -33,21 +33,20 @@ OutputVector translate_partitioned_call_op(const NodeContext& node) {
3333
// of StatefulPartitionedCall. And because otherwise they will cause a duplicates. But we need to keep them
3434
// for "internal functions of Saved Model", which are named "__inference_signature_wrapper" or
3535
// "__inference_wrapped_model".
36-
auto body_model = translate_session->get_body_ov_model(operation_type,
37-
ov_inputs,
38-
operation_type.find("wrappe") == std::string::npos);
36+
auto body_model =
37+
translate_session->get_body_ov_model(operation_type, ov_inputs, operation_type.find("wrappe") == string::npos);
3938
FRONT_END_OP_CONVERSION_CHECK(
4039
body_model,
4140
"[TensorFlow Frontend] Internal error or incorrect input model: body graph is not found for " + operation_type +
4241
".");
4342

4443
// inject the body graph into the parent graph
4544
OutputVector ov_outputs;
46-
translate_session->inject_body_model(body_model, operation_type, ov_inputs, ov_outputs);
45+
inject_body_model(body_model, operation_type, ov_inputs, ov_outputs);
4746

4847
// set output tensor names
4948
for (size_t idx = 0; idx < ov_outputs.size(); ++idx) {
50-
set_out_name({node_name + ":" + std::to_string(idx)}, ov_outputs[idx]);
49+
set_out_name({node_name + ":" + to_string(idx)}, ov_outputs[idx]);
5150
}
5251

5352
return ov_outputs;

src/frontends/tensorflow/src/op/while.cpp

+5-81
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@
44

55
#include "common_op_table.hpp"
66
#include "input_model.hpp"
7-
#include "openvino/opsets/opset10.hpp"
7+
#include "tf_utils.hpp"
88

99
using namespace std;
1010
using namespace ov;
11-
using namespace ov::opset10;
1211

1312
namespace ov {
1413
namespace frontend {
@@ -21,7 +20,7 @@ OutputVector translate_while_op(const NodeContext& node) {
2120
auto input_size_t = node.get_input_size();
2221
auto input_size = static_cast<int>(input_size_t);
2322

24-
ov::OutputVector ov_inputs;
23+
OutputVector ov_inputs;
2524
for (int input_ind = 0; input_ind < input_size; ++input_ind) {
2625
ov_inputs.push_back(node.get_input(input_ind));
2726
}
@@ -30,8 +29,8 @@ OutputVector translate_while_op(const NodeContext& node) {
3029
translate_session,
3130
"[TensorFlow Frontend] Internal error: Translate session is nullptr.");
3231
// retrieve condition and body graphs
33-
auto cond_type = node.get_attribute<std::string>("cond");
34-
auto body_type = node.get_attribute<std::string>("body");
32+
auto cond_type = node.get_attribute<string>("cond");
33+
auto body_type = node.get_attribute<string>("body");
3534
auto cond_model = translate_session->get_body_ov_model(cond_type, ov_inputs);
3635
TENSORFLOW_OP_VALIDATION(
3736
node,
@@ -43,82 +42,7 @@ OutputVector translate_while_op(const NodeContext& node) {
4342
body_model,
4443
"[TensorFlow Frontend] Internal error or incorrect input model. Cannot find body graph with name " + body_type);
4544

46-
// inject condition body graph prior to Loop node
47-
// to check condition before to start iterations
48-
auto cond_params = cond_model->get_parameters();
49-
// type setting for body graph parameters is needed for TensorList support since DT_VARIANT type is present
50-
// also for more accurate execution_condition variable shape deducing we need shape inference for condition graph
51-
for (int input_ind = 0; input_ind < input_size; ++input_ind) {
52-
cond_params[input_ind]->set_element_type(node.get_input(input_ind).get_element_type());
53-
cond_params[input_ind]->set_partial_shape(node.get_input(input_ind).get_partial_shape());
54-
}
55-
cond_model->validate_nodes_and_infer_types();
56-
57-
auto cond_prior = cond_model->clone();
58-
ov::OutputVector ov_outputs;
59-
translate_session->inject_body_model(cond_prior, node.get_name() + "/cond", ov_inputs, ov_outputs);
60-
TENSORFLOW_OP_VALIDATION(
61-
node,
62-
ov_outputs.size() == 1,
63-
"[TensorFlow Frontend] Internal error or inconsistent model: condition body must contain one Result node.");
64-
auto exec_cond = ov_outputs[0];
65-
auto trip_count = make_shared<Constant>(element::i32, Shape{}, -1);
66-
auto loop = make_shared<Loop>(trip_count, exec_cond);
67-
68-
// prepare body model to be set for the Loop node
69-
// note that condition should be computed on the updated input
70-
// because this is while(cond) {} construction,
71-
// that is why condition graph is stitched to the body results
72-
auto body_params = body_model->get_parameters();
73-
auto body_results = body_model->get_results();
74-
auto cond_results = cond_model->get_results();
75-
auto cond_params_size = cond_params.size();
76-
TENSORFLOW_OP_VALIDATION(node,
77-
body_params.size() == input_size_t,
78-
"[TensorFlow Frontend] Internal error or inconsistent model: body graph "
79-
" must have the same number of Parameter nodes as a number of inputs to While.");
80-
TENSORFLOW_OP_VALIDATION(node,
81-
body_results.size() == input_size_t,
82-
"[TensorFlow Frontend] Internal error or inconsistent model: body graphs "
83-
" must have the same number of Result nodes as a number of inputs to While.");
84-
TENSORFLOW_OP_VALIDATION(node,
85-
cond_params.size() == input_size_t,
86-
"[TensorFlow Frontend] Internal error or inconsistent model: condition graph "
87-
" must have the same number of Parameter nodes as a number of inputs to While.");
88-
for (size_t param_ind = 0; param_ind < cond_params_size; ++param_ind) {
89-
cond_params[param_ind]->output(0).replace(body_results[param_ind]->input_value(0));
90-
}
91-
92-
// update body model with the new result that corresponds to execution condition
93-
TENSORFLOW_OP_VALIDATION(
94-
node,
95-
cond_results.size() == 1 && cond_results[0],
96-
"[TensorFlow Frontend] Internal error or inconsistent model: condition body must contain one Result node.");
97-
auto body_condition_output_idx = static_cast<int64_t>(body_results.size());
98-
body_model->add_results(cond_results);
99-
100-
// type setting for body graph parameters is needed for TensorList support since DT_VARIANT type is present
101-
for (int input_ind = 0; input_ind < input_size; ++input_ind) {
102-
body_params[input_ind]->set_element_type(node.get_input(input_ind).get_element_type());
103-
}
104-
105-
// set data for the Loop node
106-
loop->set_function(body_model);
107-
108-
for (int input_ind = 0; input_ind < input_size; ++input_ind) {
109-
loop->set_merged_input(body_params[input_ind],
110-
node.get_input(input_ind),
111-
body_results[input_ind]->input_value(0));
112-
}
113-
loop->set_special_body_ports({-1, body_condition_output_idx});
114-
115-
// set external outputs for Loop node
116-
// do not get execution condition outside of the Loop node
117-
for (size_t output_ind = 0; output_ind < input_size_t; ++output_ind) {
118-
loop->get_iter_value(body_results[output_ind]);
119-
}
120-
loop->validate_and_infer_types();
121-
45+
auto loop = create_loop_for_tf_while(node.get_name(), body_model, cond_model, ov_inputs);
12246
set_node_name(node.get_name(), loop);
12347
return loop->outputs();
12448
}

src/frontends/tensorflow/src/op_table.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,18 @@ namespace op {
2222

2323
TF_OP_CONVERTER(translate_assignvariable_op);
2424
TF_OP_CONVERTER(translate_block_lstm_op);
25+
TF_OP_CONVERTER(translate_enter_op);
26+
TF_OP_CONVERTER(translate_exit_op);
2527
TF_OP_CONVERTER(translate_fifo_queue_op);
2628
TF_OP_CONVERTER(translate_gru_block_cell_op);
2729
TF_OP_CONVERTER(translate_hash_table_op);
2830
TF_OP_CONVERTER(translate_if_op);
2931
TF_OP_CONVERTER(translate_iterator_get_next_op);
3032
TF_OP_CONVERTER(translate_iterator_op);
33+
TF_OP_CONVERTER(translate_loop_cond_op);
3134
TF_OP_CONVERTER(translate_merge_op);
3235
TF_OP_CONVERTER(translate_mergev2checkpoint_op);
36+
TF_OP_CONVERTER(translate_next_iteration_op);
3337
TF_OP_CONVERTER(translate_partitioned_call_op);
3438
TF_OP_CONVERTER(translate_placeholder_linked_op);
3539
TF_OP_CONVERTER(translate_queue_dequeue_op);
@@ -310,6 +314,12 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
310314
// XLA operations
311315
{"XlaConvV2", CreatorFunction(translate_xla_conv_v2_op)},
312316
{"XlaDotV2", CreatorFunction(translate_xla_dot_op)},
317+
318+
// TF1 Control Flow operations
319+
{"Enter", CreatorFunction(translate_enter_op)},
320+
{"Exit", CreatorFunction(translate_exit_op)},
321+
{"LoopCond", CreatorFunction(translate_loop_cond_op)},
322+
{"NextIteration", CreatorFunction(translate_next_iteration_op)},
313323
};
314324
};
315325
} // namespace op

0 commit comments

Comments
 (0)