4
4
5
5
#include " common_op_table.hpp"
6
6
#include " input_model.hpp"
7
- #include " openvino/opsets/opset10 .hpp"
7
+ #include " tf_utils .hpp"
8
8
9
9
using namespace std ;
10
10
using namespace ov ;
11
- using namespace ov ::opset10;
12
11
13
12
namespace ov {
14
13
namespace frontend {
@@ -21,7 +20,7 @@ OutputVector translate_while_op(const NodeContext& node) {
21
20
auto input_size_t = node.get_input_size ();
22
21
auto input_size = static_cast <int >(input_size_t );
23
22
24
- ov:: OutputVector ov_inputs;
23
+ OutputVector ov_inputs;
25
24
for (int input_ind = 0 ; input_ind < input_size; ++input_ind) {
26
25
ov_inputs.push_back (node.get_input (input_ind));
27
26
}
@@ -30,8 +29,8 @@ OutputVector translate_while_op(const NodeContext& node) {
30
29
translate_session,
31
30
" [TensorFlow Frontend] Internal error: Translate session is nullptr." );
32
31
// retrieve condition and body graphs
33
- auto cond_type = node.get_attribute <std:: string>(" cond" );
34
- auto body_type = node.get_attribute <std:: string>(" body" );
32
+ auto cond_type = node.get_attribute <string>(" cond" );
33
+ auto body_type = node.get_attribute <string>(" body" );
35
34
auto cond_model = translate_session->get_body_ov_model (cond_type, ov_inputs);
36
35
TENSORFLOW_OP_VALIDATION (
37
36
node,
@@ -43,82 +42,7 @@ OutputVector translate_while_op(const NodeContext& node) {
43
42
body_model,
44
43
" [TensorFlow Frontend] Internal error or incorrect input model. Cannot find body graph with name " + body_type);
45
44
46
- // inject condition body graph prior to Loop node
47
- // to check condition before to start iterations
48
- auto cond_params = cond_model->get_parameters ();
49
- // type setting for body graph parameters is needed for TensorList support since DT_VARIANT type is present
50
- // also for more accurate execution_condition variable shape deducing we need shape inference for condition graph
51
- for (int input_ind = 0 ; input_ind < input_size; ++input_ind) {
52
- cond_params[input_ind]->set_element_type (node.get_input (input_ind).get_element_type ());
53
- cond_params[input_ind]->set_partial_shape (node.get_input (input_ind).get_partial_shape ());
54
- }
55
- cond_model->validate_nodes_and_infer_types ();
56
-
57
- auto cond_prior = cond_model->clone ();
58
- ov::OutputVector ov_outputs;
59
- translate_session->inject_body_model (cond_prior, node.get_name () + " /cond" , ov_inputs, ov_outputs);
60
- TENSORFLOW_OP_VALIDATION (
61
- node,
62
- ov_outputs.size () == 1 ,
63
- " [TensorFlow Frontend] Internal error or inconsistent model: condition body must contain one Result node." );
64
- auto exec_cond = ov_outputs[0 ];
65
- auto trip_count = make_shared<Constant>(element::i32, Shape{}, -1 );
66
- auto loop = make_shared<Loop>(trip_count, exec_cond);
67
-
68
- // prepare body model to be set for the Loop node
69
- // note that condition should be computed on the updated input
70
- // because this is while(cond) {} construction,
71
- // that is why condition graph is stitched to the body results
72
- auto body_params = body_model->get_parameters ();
73
- auto body_results = body_model->get_results ();
74
- auto cond_results = cond_model->get_results ();
75
- auto cond_params_size = cond_params.size ();
76
- TENSORFLOW_OP_VALIDATION (node,
77
- body_params.size () == input_size_t ,
78
- " [TensorFlow Frontend] Internal error or inconsistent model: body graph "
79
- " must have the same number of Parameter nodes as a number of inputs to While." );
80
- TENSORFLOW_OP_VALIDATION (node,
81
- body_results.size () == input_size_t ,
82
- " [TensorFlow Frontend] Internal error or inconsistent model: body graphs "
83
- " must have the same number of Result nodes as a number of inputs to While." );
84
- TENSORFLOW_OP_VALIDATION (node,
85
- cond_params.size () == input_size_t ,
86
- " [TensorFlow Frontend] Internal error or inconsistent model: condition graph "
87
- " must have the same number of Parameter nodes as a number of inputs to While." );
88
- for (size_t param_ind = 0 ; param_ind < cond_params_size; ++param_ind) {
89
- cond_params[param_ind]->output (0 ).replace (body_results[param_ind]->input_value (0 ));
90
- }
91
-
92
- // update body model with the new result that corresponds to execution condition
93
- TENSORFLOW_OP_VALIDATION (
94
- node,
95
- cond_results.size () == 1 && cond_results[0 ],
96
- " [TensorFlow Frontend] Internal error or inconsistent model: condition body must contain one Result node." );
97
- auto body_condition_output_idx = static_cast <int64_t >(body_results.size ());
98
- body_model->add_results (cond_results);
99
-
100
- // type setting for body graph parameters is needed for TensorList support since DT_VARIANT type is present
101
- for (int input_ind = 0 ; input_ind < input_size; ++input_ind) {
102
- body_params[input_ind]->set_element_type (node.get_input (input_ind).get_element_type ());
103
- }
104
-
105
- // set data for the Loop node
106
- loop->set_function (body_model);
107
-
108
- for (int input_ind = 0 ; input_ind < input_size; ++input_ind) {
109
- loop->set_merged_input (body_params[input_ind],
110
- node.get_input (input_ind),
111
- body_results[input_ind]->input_value (0 ));
112
- }
113
- loop->set_special_body_ports ({-1 , body_condition_output_idx});
114
-
115
- // set external outputs for Loop node
116
- // do not get execution condition outside of the Loop node
117
- for (size_t output_ind = 0 ; output_ind < input_size_t ; ++output_ind) {
118
- loop->get_iter_value (body_results[output_ind]);
119
- }
120
- loop->validate_and_infer_types ();
121
-
45
+ auto loop = create_loop_for_tf_while (node.get_name (), body_model, cond_model, ov_inputs);
122
46
set_node_name (node.get_name (), loop);
123
47
return loop->outputs ();
124
48
}
0 commit comments