-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsummaries.py
301 lines (243 loc) · 8.29 KB
/
summaries.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for summarizing and describing TensorFlow graphs.
This contains functions that generate string descriptions from
TensorFlow graphs, for debugging, testing, and model size
estimation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
import tensorflow as tf
from tfspecs import specs
# Abbreviations for common TensorFlow op types, used by tf_structure so
# that test cases can compare compact postfix strings against the graph
# that specs builds. Op types that are outside the scope of specs (such
# as Const and Placeholder) all map to "_" because they carry no useful
# structural information for those tests.
SHORT_NAMES_SRC = """
BiasAdd biasadd
Const _
Conv2D conv
MatMul dot
Placeholder _
Sigmoid sig
Variable var
""".split()
# Even-indexed words are op types; the word after each is its abbreviation.
SHORT_NAMES = dict(zip(SHORT_NAMES_SRC[0::2], SHORT_NAMES_SRC[1::2]))
def _truncate_structure(x):
"""A helper function that disables recursion in tf_structure.
Some constructs (e.g., HorizontalLstm) are complex unrolled
structures and don't need to be represented in the output
of tf_structure or tf_print. This helper function defines
which tree branches should be pruned. This is a very imperfect
way of dealing with unrolled LSTM's (since it truncates
useful information as well), but it's not worth doing something
better until the new fused and unrolled ops are ready.
Args:
x: a Tensor or Op
Returns:
A bool indicating whether the subtree should be pruned.
"""
if "/HorizontalLstm/" in x.name: return True
return False
def tf_structure(x, include_shapes=False, finished=None):
  """Summarize the TF graph below `x` as a postfix expression.

  The result is meant for test cases that check for gross differences in
  graph structure. It is neither invertible nor unambiguous and cannot be
  used to reconstruct the graph.

  Each op contributes its inputs' expressions first, then (optionally) the
  static shape of the tensor it was reached through, then its abbreviated
  type from SHORT_NAMES (or the lowercased type name). Identity ops emit no
  name. An op that was already emitted is written as " <>".

  Args:
    x: a tf.Tensor or tf.Operation
    include_shapes: include shapes in the output string
    finished: a set of ops that have already been output; updated in place

  Returns:
    A string of space-prefixed postfix tokens describing the structure.
  """
  if finished is None:
    finished = set()
  if isinstance(x, tf.Tensor):
    shape = x.get_shape().as_list()
    op = x.op
  else:
    shape = []
    op = x
  if op in finished:
    return " <>"
  finished |= {op}
  pieces = []
  if not _truncate_structure(op):
    # Postfix order: every operand's expression precedes the operator.
    pieces.extend(
        tf_structure(operand, include_shapes, finished)
        for operand in op.inputs)
  if include_shapes:
    pieces.append(" %s" % (shape,))
  if op.type != "Identity":
    pieces.append(" " + SHORT_NAMES.get(op.type, op.type.lower()))
  return "".join(pieces)
def tf_print(x, depth=0, finished=None, printer=print):
  """A simple print function for a TensorFlow graph.

  Emits one line per op through `printer`. Each line is indented by `depth`
  spaces and shows the op name, the op type and the static shape of the
  tensor `x` (an empty string when `x` is an Operation). The function then
  recurses into the op's inputs at `depth + 1`.

  Identity ops are looked through: the op producing the Identity's input is
  printed instead. An op that was already printed is shown again as
  `<name>` and is not expanded a second time. Subtrees selected by
  _truncate_structure are not expanded.

  Args:
    x: a tf.Tensor or tf.Operation
    depth: current printing depth (number of leading spaces)
    finished: set of ops already output; updated in place
    printer: print function to use

  Returns:
    None. All output goes through `printer`.
  """
  if finished is None:
    finished = set()
  if isinstance(x, tf.Tensor):
    shape = x.get_shape().as_list()
    x = x.op
  else:
    shape = ""
  # Skip over whole chains of Identity ops, not just one level, so that the
  # tree hides Identity consistently (tf_structure also omits Identity names).
  while x.type == "Identity":
    x = x.inputs[0].op
  indent = " " * depth
  if x in finished:
    printer("%s<%s> %s %s" % (indent, x.name, x.type, shape))
    return
  finished |= {x}
  printer("%s%s %s %s" % (indent, x.name, x.type, shape))
  if not _truncate_structure(x):
    for y in x.inputs:
      tf_print(y, depth + 1, finished, printer=printer)
def tf_num_params(x):
  """Number of parameters in a TensorFlow subgraph.

  Walks the graph backwards from `x` along op inputs and adds up the
  element counts of every variable op it reaches. Each op is visited at
  most once. As a result, a variable reachable through several paths is
  counted a single time and shared sub-graphs are not re-traversed.

  Args:
    x: root of the subgraph (Tensor or Operation)

  Returns:
    Total number of elements found in all Variables in the subgraph.
  """
  # Both the legacy "Variable" op type and its successor "VariableV2" hold
  # parameters. The element count is read from the op's first output.
  variable_types = ("Variable", "VariableV2")
  # Resolve through .op for Tensors. Operations are used directly; the old
  # code left `shape` unbound in that case.
  root = x.op if isinstance(x, tf.Tensor) else x
  visited = set()
  pending = [root]
  total = 0
  # An explicit stack instead of recursion: unrolled graphs can be deeper
  # than Python's recursion limit.
  while pending:
    op = pending.pop()
    if op in visited:
      continue
    visited.add(op)
    if op.type in variable_types:
      total += op.outputs[0].get_shape().num_elements()
    else:
      pending.extend(t.op for t in op.inputs)
  return total
def tf_left_split(op):
  """Separate the leftmost input of `op` from the rest, for left recursion.

  For "Concat" ops the first input is skipped (the remaining inputs are
  taken from index 1 onward), so the leftmost returned input is inputs[1].

  Args:
    op: tf.Operation (anything with `type` and a sliceable `inputs`)

  Returns:
    A tuple of the leftmost input tensor and the inputs after it. If `op`
    has no inputs, returns (None, []).
  """
  inputs = op.inputs
  if len(inputs) == 0:
    return None, []
  start = 1 if op.type == "Concat" else 0
  return inputs[start], inputs[start + 1:]
def tf_parameter_iter(x):
  """Walk the leftmost chain of inputs from `x`, yielding parameter sizes.

  At each step the current op is split with tf_left_split. The parameters
  found in the non-left inputs are counted with tf_num_params, and the walk
  continues with the left input until there is none.

  Args:
    x: root of the subgraph (Tensor, Operation)

  Yields:
    A triple of (op name, number of params in the right branches, shape).
    The shape is the tensor's static shape as a list, or "" when the
    current node was reached as an Operation.
  """
  node = x
  while True:
    if isinstance(node, tf.Tensor):
      shape = node.get_shape().as_list()
      node = node.op
    else:
      shape = ""
    left, rest = tf_left_split(node)
    yield node.name, sum(tf_num_params(branch) for branch in rest), shape
    if left is None:
      return
    node = left
def _combine_filter(x):
"""A filter for combining successive layers with similar names."""
last_name = None
last_total = 0
last_shape = None
for name, total, shape in x:
name = re.sub("/.*", "", name)
if name == last_name:
last_total += total
continue
if last_name is not None:
yield last_name, last_total, last_shape
last_name = name
last_total = total
last_shape = shape
if last_name is not None:
yield last_name, last_total, last_shape
def tf_parameter_summary(x, printer=print, combine=True):
  """Print a per-layer parameter table for the left chain below `x`.

  Rows come from tf_parameter_iter, optionally merged by top-level scope,
  and are printed in reverse order: the end of the left chain first, `x`
  last. Each row shows the parameter count, the name and the shape.

  Args:
    x: root of the subgraph (Tensor, Operation)
    printer: print function for output
    combine: combine layers by top-level scope
  """
  rows = tf_parameter_iter(x)
  if combine:
    rows = _combine_filter(rows)
  for name, total, shape in reversed(list(rows)):
    printer("%10d %-20s %s" % (total, name, shape))
def tf_spec_structure(spec, inputs=None, input_shape=None,
                      input_type=tf.float32):
  """Build the network for `spec` and return its postfix structure string.

  Intended for test cases that check for gross differences in graph
  structure. The string is neither invertible nor unambiguous and cannot
  be used to reconstruct the graph.

  Args:
    spec: specification
    inputs: input to the spec construction (usually a Tensor)
    input_shape: tensor shape (in lieu of inputs)
    input_type: type of the input tensor

  Returns:
    The tf_structure string of the constructed network, stripped of
    surrounding whitespace.
  """
  # Without explicit inputs, feed the spec a fresh placeholder.
  source = (inputs if inputs is not None
            else tf.placeholder(input_type, input_shape))
  network = specs.create_net(spec, source)
  return str(tf_structure(network).strip())
def tf_spec_summary(spec, inputs=None, input_shape=None, input_type=tf.float32):
  """Build the network for `spec` and print its parameter summary.

  Prints the left-most chain of tensor operations and summarizes the
  variables found in the right branches (see tf_parameter_summary). This
  view is most useful for networks that are structured like pipelines.

  Args:
    spec: specification
    inputs: input to the spec construction (usually a Tensor)
    input_shape: optional shape of input
    input_type: type of the input tensor
  """
  # Without explicit inputs, feed the spec a fresh placeholder.
  source = (inputs if inputs is not None
            else tf.placeholder(input_type, input_shape))
  tf_parameter_summary(specs.create_net(spec, source))
def tf_spec_print(spec, inputs=None, input_shape=None, input_type=tf.float32):
  """Build the network for `spec` and print it as a tree via tf_print.

  Args:
    spec: specification
    inputs: input to the spec construction (usually a Tensor)
    input_shape: optional shape of input
    input_type: type of the input tensor
  """
  # Without explicit inputs, feed the spec a fresh placeholder.
  source = (inputs if inputs is not None
            else tf.placeholder(input_type, input_shape))
  tf_print(specs.create_net(spec, source))