Skip to content

Commit 91b926f

Browse files
kflufacebook-github-bot
authored andcommitted
Add fx2trt pass for removing duplicate output args (pytorch#64461)
Summary: Pull Request resolved: pytorch#64461 Fx2TRT does not support duplicate nodes in the output args tuple. This pass removes duplicate output args from the target subnets and fixes their uses in the top level module where the subnets are called. This pass must be called after acc split on the top-level net and subsequent calls to the acc trace on the subnets. This pass will change both the subnets and top level module. Test Plan: Run: ``` buck run mode/opt -c python.package_style=inplace //caffe2/torch/fb/fx2trt/tests/passes/:test_remove_duplicate_output_args ``` Reviewed By: yinghai Differential Revision: D30740499 fbshipit-source-id: 98459f7677980b21c7bffda918158001285572db
1 parent 39aeb3b commit 91b926f

File tree

1 file changed

+134
-0
lines changed

1 file changed

+134
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
#!/usr/bin/env python3
2+
3+
import operator
4+
import typing as t
5+
import logging
6+
import torch.fx as fx
7+
import dataclasses as dc
8+
9+
10+
_LOGGER = logging.getLogger(__name__)
11+
12+
13+
def remove_duplicate_output_args(
14+
top_level: fx.GraphModule,
15+
target_subnets: t.Collection[str]
16+
) -> t.Mapping[str, "RemoveDuplicateResult"]:
17+
"""Removes duplicate output args.
18+
19+
This pass removes duplicate output args from the target subnets and fixes
20+
their uses in the top level module where the subnets are called. This pass
21+
must be called after acc split on the top-level net and subsequent calls to
22+
the acc trace on the subnets.
23+
24+
This pass will change both the subnets and top level module.
25+
26+
Returns:
27+
a mapping of the target subnet name to its dedupcate result
28+
"""
29+
30+
processed_subnets = {}
31+
for node in top_level.graph.nodes: # type: fx.Node
32+
if node.op == "call_module" and node.name in target_subnets:
33+
assert isinstance(node.target, str)
34+
sub_gm = top_level.get_submodule(node.target)
35+
assert isinstance(sub_gm, fx.GraphModule)
36+
37+
replace_res = _remove_duplicate_output_args(sub_gm)
38+
processed_subnets[node.name] = replace_res
39+
if replace_res.replacement_map is None:
40+
continue
41+
sub_gm.recompile()
42+
43+
needs_recompile = False
44+
# iterate on the copy since we will be changing elements of node.users
45+
for user in list(node.users):
46+
idx = _ensure_proper_output_use(user, node)
47+
idx_new = replace_res.replacement_map[idx]
48+
if idx_new != idx:
49+
user.args = (user.args[0], idx_new)
50+
needs_recompile = True
51+
52+
if needs_recompile:
53+
top_level.recompile()
54+
return processed_subnets
55+
56+
57+
@dc.dataclass(frozen=True)
58+
class RemoveDuplicateResult:
59+
replacement_map: t.Optional[t.List[int]]
60+
module: fx.GraphModule
61+
62+
63+
def _ensure_proper_output_use(user: fx.Node, target_node: fx.Node) -> int:
64+
"""
65+
Ensures the node looks in proper form of calling the output of an fx2trt
66+
splitter sub-net. Specifically:
67+
68+
1. op is call function, target: operator.getitem
69+
2. args is a 2-element tuple
70+
3. args[0] is the name of the subnet's output
71+
4. args[1] is the index into the subnet output tuple
72+
73+
E.g.:
74+
75+
%getitem_4 : [#users=1] = call_function[target=operator.getitem](args = (%_run_on_acc_1, 4), kwargs = {})
76+
77+
returns the index into the subnet output tuple
78+
"""
79+
_LOGGER.info(f"Checking user node: {user.format_node()}")
80+
assert (
81+
user.op == "call_function"
82+
and user.target == operator.getitem
83+
and len(user.args) == 2
84+
and isinstance(user.args[0], fx.Node)
85+
and user.args[0].name == target_node.name
86+
and isinstance(user.args[1], int)
87+
), f"Node is not a proper user of splitter output: {user.format_node()}"
88+
89+
return user.args[1]
90+
91+
92+
def _remove_duplicate_output_args(gm: fx.GraphModule) -> RemoveDuplicateResult:
93+
output_nodes = [n for n in gm.graph.nodes if n.op == "output"]
94+
assert len(output_nodes) == 1, \
95+
f"Expecting exactly one `output` node, but got {len(output_nodes)}"
96+
97+
changed = False
98+
# arg node name to its index in the new output args tuple
99+
name_to_idx: t.Dict[str, int] = {}
100+
output_node = output_nodes[0]
101+
102+
# Output op only uses its `args[0]`, and it does not have `kwargs`.
103+
# https://pytorch.org/docs/stable/fx.html#torch.fx.Node
104+
args: t.Sequence[t.Any] = output_node.args[0]
105+
106+
# Only concern outselves to the case where the args is an iterable of fx.Node.
107+
# Other return cases (e.g., a single value) is possible and we don't handle
108+
# that in this pass.
109+
if not (isinstance(args, t.Iterable) and all(isinstance(a, fx.Node) for a in args)):
110+
return RemoveDuplicateResult(replacement_map=None, module=gm)
111+
112+
# Map old index of the arg node to the remaining node's idx,
113+
# initialized to `i => i`
114+
replacement_map: t.List[int] = list(range(len(args)))
115+
args_new = []
116+
for idx, a in enumerate(args):
117+
assert isinstance(a, fx.Node), \
118+
f"Expecting fx.Node instance, but got: {type(a)}"
119+
120+
if a.name not in name_to_idx:
121+
args_new.append(a)
122+
name_to_idx[a.name] = len(args_new) - 1
123+
else:
124+
changed = True
125+
_LOGGER.warning(
126+
f"Replaced duplicate output arg '{a.name}': "
127+
f"{idx} -> {name_to_idx[a.name]}"
128+
)
129+
replacement_map[idx] = name_to_idx[a.name]
130+
131+
output_node.args = (tuple(args_new),)
132+
if changed:
133+
gm.recompile()
134+
return RemoveDuplicateResult(replacement_map, module=gm)

0 commit comments

Comments
 (0)