Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

removed unused packages and nodes of graph #108

Open
wants to merge 1 commit into
base: r1.15.5+nv23.03
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 92 additions & 93 deletions tensorflow/python/tools/strip_unused_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,104 +21,103 @@
import copy

from google.protobuf import text_format

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile


def strip_unused(input_graph_def, input_node_names, output_node_names,
placeholder_type_enum):
"""Removes unused nodes from a GraphDef.

Args:
input_graph_def: A graph with nodes we want to prune.
input_node_names: A list of the nodes we use as inputs.
output_node_names: A list of the output nodes.
placeholder_type_enum: The AttrValue enum for the placeholder data type, or
a list that specifies one value per input node name.

Returns:
A `GraphDef` with all unnecessary ops removed.

Raises:
ValueError: If any element in `input_node_names` refers to a tensor instead
of an operation.
KeyError: If any element in `input_node_names` is not found in the graph.
"""
for name in input_node_names:
if ":" in name:
raise ValueError("Name '%s' appears to refer to a Tensor, "
"not a Operation." % name)

# Here we replace the nodes we're going to override as inputs with
# placeholders so that any unused nodes that are inputs to them are
# automatically stripped out by extract_sub_graph().
not_found = {name for name in input_node_names}
inputs_replaced_graph_def = graph_pb2.GraphDef()
for node in input_graph_def.node:
if node.name in input_node_names:
not_found.remove(node.name)
placeholder_node = node_def_pb2.NodeDef()
placeholder_node.op = "Placeholder"
placeholder_node.name = node.name
if isinstance(placeholder_type_enum, list):
input_node_index = input_node_names.index(node.name)
placeholder_node.attr["dtype"].CopyFrom(
attr_value_pb2.AttrValue(type=placeholder_type_enum[
input_node_index]))
else:
placeholder_node.attr["dtype"].CopyFrom(
attr_value_pb2.AttrValue(type=placeholder_type_enum))
if "_output_shapes" in node.attr:
placeholder_node.attr["_output_shapes"].CopyFrom(node.attr[
"_output_shapes"])
if "shape" in node.attr:
placeholder_node.attr["shape"].CopyFrom(node.attr["shape"])
inputs_replaced_graph_def.node.extend([placeholder_node])
def strip_unused(input_graph_def, input_node_names, output_node_names, placeholder_type_enum):
"""
Removes unused nodes from a GraphDef.

Args:
input_graph_def: A graph with nodes we want to prune.
input_node_names: A list of the nodes we use as inputs.
output_node_names: A list of the output nodes.
placeholder_type_enum: The AttrValue enum for the placeholder data type, or
a list that specifies one value per input node name.

Returns:
A `GraphDef` with all unnecessary ops removed.

Raises:
ValueError: If any element in `input_node_names` refers to a tensor instead
of an operation.
KeyError: If any element in `input_node_names` is not found in the graph.
"""
for name in input_node_names:
if ":" in name:
raise ValueError("Name '%s' appears to refer to a Tensor, not an Operation." % name)

not_found = set(input_node_names)
inputs_replaced_graph_def = graph_pb2.GraphDef()

for node in input_graph_def.node:
if node.name in input_node_names:
not_found.remove(node.name)
placeholder_node = node_def_pb2.NodeDef()
placeholder_node.op = "Placeholder"
placeholder_node.name = node.name
if isinstance(placeholder_type_enum, list):
input_node_index = input_node_names.index(node.name)
placeholder_node.attr["dtype"].CopyFrom(
attr_value_pb2.AttrValue(type=placeholder_type_enum[input_node_index]))
else:
placeholder_node.attr["dtype"].CopyFrom(
attr_value_pb2.AttrValue(type=placeholder_type_enum))
if "_output_shapes" in node.attr:
placeholder_node.attr["_output_shapes"].CopyFrom(node.attr["_output_shapes"])
if "shape" in node.attr:
placeholder_node.attr["shape"].CopyFrom(node.attr["shape"])
inputs_replaced_graph_def.node.extend([placeholder_node])
else:
inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])

if not_found:
raise KeyError("The following input nodes were not found: %s" % ", ".join(not_found))

output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def, output_node_names)
return output_graph_def

def strip_unused_from_files(input_graph, input_binary, output_graph, output_binary, input_node_names, output_node_names, placeholder_type_enum):
"""
Removes unused nodes from a graph file.

Args:
input_graph: Path to the input graph file.
input_binary: Boolean indicating whether the input graph file is in binary format.
output_graph: Path to save the output graph file.
output_binary: Boolean indicating whether to save the output graph file in binary format.
input_node_names: Comma-separated string of input node names.
output_node_names: Comma-separated string of output node names.
placeholder_type_enum: The AttrValue enum for the placeholder data type, or
a list that specifies one value per input node name.

Returns:
None
"""
if not gfile.Exists(input_graph):
raise FileNotFoundError("Input graph file '%s' does not exist!" % input_graph)

if not output_node_names:
raise ValueError("You need to supply the name of a node to --output_node_names.")

input_graph_def = graph_pb2.GraphDef()
mode = "rb" if input_binary else "r"
with gfile.GFile(input_graph, mode) as f:
if input_binary:
input_graph_def.ParseFromString(f.read())
else:
text_format.Merge(f.read(), input_graph_def)

output_graph_def = strip_unused(input_graph_def, input_node_names.split(","), output_node_names.split(","), placeholder_type_enum)

if output_binary:
with gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())
else:
inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])

if not_found:
raise KeyError("The following input nodes were not found: %s" % not_found)

output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
output_node_names)
return output_graph_def

with gfile.GFile(output_graph, "w") as f:
f.write(text_format.MessageToString(output_graph_def))

def strip_unused_from_files(input_graph, input_binary, output_graph,
output_binary, input_node_names, output_node_names,
placeholder_type_enum):
"""Removes unused nodes from a graph file."""

if not gfile.Exists(input_graph):
print("Input graph file '" + input_graph + "' does not exist!")
return -1

if not output_node_names:
print("You need to supply the name of a node to --output_node_names.")
return -1

input_graph_def = graph_pb2.GraphDef()
mode = "rb" if input_binary else "r"
with gfile.GFile(input_graph, mode) as f:
if input_binary:
input_graph_def.ParseFromString(f.read())
else:
text_format.Merge(f.read(), input_graph_def)

output_graph_def = strip_unused(input_graph_def,
input_node_names.split(","),
output_node_names.split(","),
placeholder_type_enum)

if output_binary:
with gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())
else:
with gfile.GFile(output_graph, "w") as f:
f.write(text_format.MessageToString(output_graph_def))
print("%d ops in the final graph." % len(output_graph_def.node))
print("%d ops in the final graph." % len(output_graph_def.node))