Skip to content

Commit 0435af5

Browse files
Comments
1 parent aa95459 commit 0435af5

File tree

2 files changed

+37
-43
lines changed

2 files changed

+37
-43
lines changed

nncf/experimental/torch/fx/constant_folding.py

+26-27
Original file line numberDiff line numberDiff line change
@@ -248,30 +248,29 @@ def constant_fold(
248248
:param constraint_fn: Constraint function which takes a node and returns the constraint:
249249
should the node be constant folded or not.
250250
"""
251-
with torch.no_grad():
252-
with torch.utils._python_dispatch._disable_current_modes():
253-
cf = ConstantFolder(gm)
254-
cf.run()
255-
256-
device = get_model_device(gm)
257-
for node, constant in cf.node_replacements.items():
258-
if constraint_fn is not None and not constraint_fn(node):
259-
continue
260-
constant = constant.to(device)
261-
_replace_node_with_constant(gm, node, constant)
262-
263-
erased_params = []
264-
for node in gm.graph.find_nodes(op="get_attr"):
265-
if len(node.users) == 0:
266-
if hasattr(gm, node.target):
267-
delattr(gm, node.target)
268-
erased_params.append(node)
269-
270-
for node in erased_params:
271-
gm.graph.erase_node(node)
272-
273-
# Custom _is_impure function allows to eliminate all layers with zero
274-
# users including inplace ops like relu_ besides output and placeholders.
275-
gm.graph.eliminate_dead_code(_is_impure)
276-
gm.graph.lint()
277-
gm.recompile()
251+
with torch.no_grad(), torch.utils._python_dispatch._disable_current_modes():
252+
cf = ConstantFolder(gm)
253+
cf.run()
254+
255+
device = get_model_device(gm)
256+
for node, constant in cf.node_replacements.items():
257+
if constraint_fn is not None and not constraint_fn(node):
258+
continue
259+
constant = constant.to(device)
260+
_replace_node_with_constant(gm, node, constant)
261+
262+
erased_params = []
263+
for node in gm.graph.find_nodes(op="get_attr"):
264+
if len(node.users) == 0:
265+
if hasattr(gm, node.target):
266+
delattr(gm, node.target)
267+
erased_params.append(node)
268+
269+
for node in erased_params:
270+
gm.graph.erase_node(node)
271+
272+
# Custom _is_impure function allows to eliminate all layers with zero
273+
# users including inplace ops like relu_ besides output and placeholders.
274+
gm.graph.eliminate_dead_code(_is_impure)
275+
gm.graph.lint()
276+
gm.recompile()
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,15 @@
11
strict digraph {
22
"0 linear_weight" [id=0, type=get_attr];
33
"1 linear_bias" [id=1, type=get_attr];
4-
"2 lifted_tensor_0" [id=2, type=get_attr];
5-
"3 x" [id=3, type=input];
6-
"4 lift_fresh_copy" [id=4, type=lift_fresh_copy];
7-
"5 detach_" [id=5, type=detach_];
8-
"6 _frozen_param0" [id=6, type=get_attr];
9-
"7 linear" [id=7, type=linear];
10-
"8 add" [id=8, type=add];
11-
"9 output" [id=9, type=output];
12-
"0 linear_weight" -> "7 linear" [label="(3, 3)", style=solid];
13-
"1 linear_bias" -> "7 linear" [label="(3,)", style=solid];
14-
"2 lifted_tensor_0" -> "4 lift_fresh_copy" [label="()", style=solid];
15-
"3 x" -> "7 linear" [label="(1, 3, 3, 3)", style=solid];
16-
"4 lift_fresh_copy" -> "5 detach_" [label="()", style=solid];
17-
"6 _frozen_param0" -> "8 add" [label="()", style=solid];
18-
"7 linear" -> "8 add" [label="(1, 3, 3, 3)", style=solid];
19-
"8 add" -> "9 output" [label="(1, 3, 3, 3)", style=solid];
4+
"2 x" [id=2, type=input];
5+
"3 _frozen_param0" [id=3, type=get_attr];
6+
"4 linear" [id=4, type=linear];
7+
"5 add" [id=5, type=add];
8+
"6 output" [id=6, type=output];
9+
"0 linear_weight" -> "4 linear" [label="(3, 3)", style=solid];
10+
"1 linear_bias" -> "4 linear" [label="(3,)", style=solid];
11+
"2 x" -> "4 linear" [label="(1, 3, 3, 3)", style=solid];
12+
"3 _frozen_param0" -> "5 add" [label="()", style=solid];
13+
"4 linear" -> "5 add" [label="(1, 3, 3, 3)", style=solid];
14+
"5 add" -> "6 output" [label="(1, 3, 3, 3)", style=solid];
2015
}

0 commit comments

Comments
 (0)