forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_tf_NestedWhile.py
83 lines (61 loc) · 3.55 KB
/
test_tf_NestedWhile.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
from common.tf_layer_test_class import CommonTFLayerTest
class TestNestedWhile(CommonTFLayerTest):
    """Layer tests for TensorFlow ``While`` loops, both single and nested.

    Each ``create_*`` method builds a graph containing a ``tf.while_loop``
    named ``"while_node"`` and returns ``(graph, None)``. That pair is
    unpacked as the first two positional arguments of ``self._test``.
    """

    @staticmethod
    def _create_loop_inputs():
        """Create the initial loop variables shared by both test models.

        Must be called while the target ``tf.Graph`` is the default graph.

        Returns:
            Tuple ``(x, v, i, a_combined, b_combined)``:
            ``x`` is a float32 placeholder of shape (3, 2); ``v`` is the int32
            constant ``[1, 2, 3]``; ``i`` is the int32 counter ``[0]``;
            ``a_combined`` and ``b_combined`` are float32 zeros of shape (1, 2).
        """
        import tensorflow as tf

        x = tf.compat.v1.placeholder(tf.float32, shape=(3, 2))
        v = tf.constant([1, 2, 3], dtype=tf.int32, shape=[3])
        i = tf.constant([0], dtype=tf.int32, shape=[1])
        a_combined = tf.zeros([1, 2], dtype=tf.float32)
        b_combined = tf.zeros([1, 2], dtype=tf.float32)
        return x, v, i, a_combined, b_combined

    def create_simple_while(self):
        """Build a graph with a single, non-nested ``tf.while_loop``.

        The loop runs while ``i < v.shape[0]`` (3 iterations, since ``i``
        starts at 0 and is incremented by 1). Each iteration adds the first
        row of ``x`` to ``a_combined``; ``b_combined`` is carried through
        unchanged.

        Returns:
            ``(g, None)`` where ``g`` is the constructed ``tf.Graph``.
        """
        import tensorflow as tf

        g = tf.Graph()
        with g.as_default():
            x, v, i, a_combined, b_combined = self._create_loop_inputs()

            def body(x_arg, v_arg, i_arg, a_combined_arg, b_combined_arg):
                # Slices row 0 on every iteration (not row i); result is (1, 2).
                x_slice = tf.slice(x_arg, [0, 0], [1, x_arg.shape[1]])
                i_arg = tf.add(i_arg, 1)
                a_combined_arg = tf.add(a_combined_arg, x_slice)
                return x_arg, v_arg, i_arg, a_combined_arg, b_combined_arg

            def while_condition(x_arg, v_arg, i_arg, a_combined_arg, b_combined_arg):
                return i_arg < v_arg.shape[0]

            tf.while_loop(while_condition, body, [x, v, i, a_combined, b_combined],
                          name="while_node")
        return g, None

    def create_nested_while(self):
        """Build a graph with a ``tf.while_loop`` nested inside another.

        The outer loop runs while ``i < v.shape[0]`` (3 iterations). In each
        outer iteration an inner loop runs while ``j < v[0:1]``. Because
        ``j`` starts at 0 and ``v[0]`` is 1, the inner loop runs once. The
        inner loop adds the first row of ``x`` to ``b_combined``, and the
        outer body then adds that row to ``a_combined``.

        Returns:
            ``(g, None)`` where ``g`` is the constructed ``tf.Graph``.
        """
        import tensorflow as tf

        g = tf.Graph()
        with g.as_default():
            x, v, i, a_combined, b_combined = self._create_loop_inputs()

            def body(x_arg, v_arg, i_arg, a_combined_arg, b_combined_arg):
                x_slice = tf.slice(x_arg, [0, 0], [1, x_arg.shape[1]])
                v_slice = tf.slice(v_arg, [0], [1])
                j = tf.constant([0], dtype=tf.int32, shape=[1])

                def inner_body(x_slice_arg, v_slice_arg, j_arg, b_acc):
                    j_arg = tf.add(j_arg, 1)
                    b_acc = tf.add(b_acc, x_slice_arg)
                    return x_slice_arg, v_slice_arg, j_arg, b_acc

                def inner_condition(x_slice_arg, v_slice_arg, j_arg, b_acc):
                    return tf.less(j_arg, v_slice_arg)

                # Rebind x_slice to the inner loop's output so the outer
                # accumulation below consumes a tensor produced by the inner
                # While op rather than the pre-loop slice.
                x_slice, _, _, b_combined_arg = tf.while_loop(
                    inner_condition, inner_body, [x_slice, v_slice, j, b_combined_arg])
                i_arg = tf.add(i_arg, 1)
                a_combined_arg = tf.add(a_combined_arg, x_slice)
                return x_arg, v_arg, i_arg, a_combined_arg, b_combined_arg

            def while_condition(x_arg, v_arg, i_arg, a_combined_arg, b_combined_arg):
                return i_arg < v_arg.shape[0]

            tf.while_loop(while_condition, body, [x, v, i, a_combined, b_combined],
                          name="while_node")
        return g, None

    @pytest.mark.nightly
    def test_simple_while(self, ie_device, precision, ir_version, temp_dir, use_new_frontend,
                          use_old_api):
        """Convert and run the single-loop model."""
        self._test(*self.create_simple_while(), ie_device, precision, ir_version,
                   temp_dir=temp_dir, use_new_frontend=use_new_frontend,
                   use_old_api=use_old_api)

    @pytest.mark.precommit_tf_fe
    @pytest.mark.nightly
    def test_nested_while(self, ie_device, precision, ir_version, temp_dir, use_new_frontend,
                          use_old_api):
        """Convert and run the nested-loop model."""
        self._test(*self.create_nested_while(), ie_device, precision, ir_version,
                   temp_dir=temp_dir, use_new_frontend=use_new_frontend,
                   use_old_api=use_old_api)