forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnomnigraph_transformations_test.py
144 lines (129 loc) · 5.63 KB
/
nomnigraph_transformations_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from caffe2.python import core, workspace
from caffe2.python import test_util as tu
import caffe2.python.nomnigraph as ng
from caffe2.python.nomnigraph_transformations import transpose_network
import numpy as np
from hypothesis import given
import hypothesis.strategies as st
class TestNomnigraphTransformations(tu.TestCase):
    """Tests for editing nomnigraph NNModules and for transpose_network.

    Each test builds a small Caffe2 net, manipulates its NNModule
    representation, converts it back to a NetDef, runs it, and compares the
    fetched blobs against expected values.
    """

    @staticmethod
    def _run_in_fresh_workspace(netdef, feeds):
        """Reset the workspace, feed each (name, value) pair, run netdef once."""
        workspace.ResetWorkspace()
        for blob_name, blob_value in feeds:
            workspace.FeedBlob(blob_name, blob_value)
        workspace.RunNetOnce(netdef)

    def test_simple_replace(self):
        """Replace the net's single FC operator with an Add node.

        The expected output [2, 4, 6] is the elementwise sum of the two fed
        blobs X and W.
        """
        net = core.Net("name")
        net.FC(["X", "W"], ["Y"])
        nn = ng.NNModule(net)

        old_op = nn.controlFlow[0]
        add_def = core.CreateOperator("Add", ["X"], ["Y"], engine="CUDNN")
        new_op = nn.createNode(add_def)
        nn.replaceNode(old_op, new_op)
        nn.deleteNode(old_op)

        # Convert back to a NetDef and execute it on known inputs.
        rewritten = nn.convertToCaffe2Proto()
        self._run_in_fresh_workspace(
            rewritten,
            [("X", np.array([1, 2, 3])), ("W", np.array([1, 2, 3]))],
        )
        np.testing.assert_almost_equal(
            workspace.FetchBlob("Y"), np.array([2, 4, 6])
        )

    def test_simple_rewire(self):
        """Swap the roles of Mul and Add by deleting and recreating edges.

        The net is built as
            c = Mul(a, b)
            e = Add(c, d)
        and rewired into
            c = Add(a, d)
            e = Mul(c, b)

        With a = 1, b = 2, d = 3 the rewired net yields (1 + 3) * 2 = 8,
        whereas the net as built would yield (1 * 2) + 3 = 5.
        """
        net = core.Net("name")
        net.Mul(["a", "b"], ["c"])
        net.Add(["c", "d"], ["e"])
        nn = ng.NNModule(net)

        mul, add = nn.controlFlow[0], nn.controlFlow[1]
        a, b = mul.inputs[0], mul.inputs[1]
        c = mul.outputs[0]
        d = add.inputs[1]
        e = add.outputs[0]

        # Detach every tensor from both operators ...
        stale_edges = (
            (a, mul),
            (b, mul),
            (mul, c),
            (c, add),
            (d, add),
            (add, e),
        )
        for tail, head in stale_edges:
            nn.deleteEdge(tail, head)

        # ... then reattach them with Add feeding Mul. The order is kept
        # exactly as listed because it may determine operand positions.
        fresh_edges = (
            (a, add),
            (d, add),
            (add, c),
            (c, mul),
            (b, mul),
            (mul, e),
        )
        for tail, head in fresh_edges:
            nn.createEdge(tail, head)

        # Convert back to a NetDef and execute it on known inputs.
        rewritten = nn.convertToCaffe2Proto()
        self._run_in_fresh_workspace(
            rewritten,
            [
                ("a", np.array([1, 1, 1])),
                ("b", np.array([2, 2, 2])),
                ("d", np.array([3, 3, 3])),
            ],
        )
        np.testing.assert_almost_equal(
            workspace.FetchBlob("e"), np.array([8, 8, 8])
        )

    @given(
        batch_size=st.integers(16, 20),
        channels=st.integers(1, 10),
        height=st.integers(10, 15),
        width=st.integers(10, 15),
        seed=st.integers(0, 65535),
        kernel=st.integers(3, 5),
    )
    def test_transpose_network(self, batch_size, channels, height, width, seed,
                               kernel):
        """Check that transpose_network preserves results and adds 9 ops/tensors.

        A two-level convolution net is run once as built and once after
        transpose_network(nn); the blobs c1, c3 and out must agree to one
        decimal place, and the operator and tensor counts must each grow by
        exactly nine.
        """
        conv_args = dict(stride=1, pad=0, kernel=kernel)
        net = core.Net("net")
        # First level: c1 and c2 are
        # batch_size, 2*channels, height - kernel + 1, width - kernel + 1.
        # Second level: c3 and c4 both consume c1; their spatial size is
        # height - 2*kernel + 2 by width - 2*kernel + 2. w3/w4 are created
        # below with 4*channels as their leading dimension, so c3/c4
        # presumably carry 4*channels output channels.
        conv_specs = (
            (["X", "w1", "b1"], "c1"),
            (["X", "w2", "b2"], "c2"),
            (["c1", "w3", "b3"], "c3"),
            (["c1", "w4", "b4"], "c4"),
        )
        for conv_inputs, conv_output in conv_specs:
            net.Conv(conv_inputs, [conv_output], **conv_args)
        for source, flattened in (("c3", "c3f"), ("c4", "c4f"), ("X", "Xf")):
            net.Flatten([source], flattened)
        net.Concat(["c3f", "c4f", "Xf"], ["out", "split_info"], axis=1, add_axis=0)

        # Populate the workspace with seeded random inputs and snapshot them
        # so the transformed net can be run on exactly the same data.
        np.random.seed(seed)
        workspace.ResetWorkspace()
        tu.randBlobFloat32("X", batch_size, channels, height, width)
        tu.randBlobsFloat32(["w1", "w2"], 2 * channels, channels, kernel, kernel)
        tu.randBlobsFloat32(["b1", "b2"], 2 * channels)
        tu.randBlobsFloat32(["w3", "w4"], 4 * channels, 2 * channels, kernel, kernel)
        tu.randBlobsFloat32(["b3", "b4"], 4 * channels)
        all_inp_names = ["X", "w1", "w2", "b1", "b2", "w3", "w4", "b3", "b4"]
        all_input = workspace.FetchBlobs(all_inp_names)

        # Reference run on the untransformed net.
        watched = ("c1", "c3", "out")
        workspace.RunNetOnce(net)
        reference = [workspace.FetchBlob(blob) for blob in watched]

        nn = ng.NNModule(net)
        ops_before = len(nn.operators)
        tensors_before = len(nn.tensors)
        transpose_network(nn)
        new_netdef = nn.convertToCaffe2Proto()
        ops_after = len(nn.operators)
        tensors_after = len(nn.tensors)

        # The minimal number of additional operators and tensors is at least
        # one NCHW2NHWC operator and tensor for each channel-based input
        # tensor and one NHWC2NCHW operator and tensor for the output of each
        # convolution:
        #   channel-based inputs: X, w1, w2, w3, w4
        #   convolution outputs:  c1, c2, c3, c4
        # i.e. a total of 9.
        self.assertEqual(ops_after,
                         ops_before + 9,
                         "expected 9 additional operators")
        self.assertEqual(tensors_after,
                         tensors_before + 9,
                         "expected 9 additional tensors")

        # Re-run the transformed net on the snapshotted inputs and compare.
        self._run_in_fresh_workspace(new_netdef, zip(all_inp_names, all_input))
        transformed = [workspace.FetchBlob(blob) for blob in watched]
        for actual, desired in zip(transformed, reference):
            np.testing.assert_almost_equal(actual, desired, decimal=1)