
Commit 5666d20

tugsbayasgalan authored and pytorchmergebot committed on Jul 19, 2023
Add unlifting pass under private config (pytorch#104897)
Summary: We want to roll this out little by little. For now, it is tried only on DissectedPartsModel, which needs to use the aot_export version.

Test Plan: CI

Reviewed By: zhxchen17

Differential Revision: D46785735

Pull Request resolved: pytorch#104897

Approved by: https://github.com/JacobSzwejbka
1 parent fbd7e74 commit 5666d20

3 files changed: +55 −26 lines
 

‎torch/_export/__init__.py

+46 −23
@@ -1,41 +1,44 @@
 import dataclasses
 import inspect
-import weakref
 import re
+import weakref
 from collections import OrderedDict
-from typing import Any, Callable, List, Tuple, Optional, Dict, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import sympy

 import torch
 import torch._dynamo
 import torch.fx
-from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
-from .exported_program import (
-    CallSpec,
-    ExportedProgram,
-    ExportBackwardSignature,
-    ExportGraphSignature,
-    _process_constraints,
-)
-from .passes.replace_sym_size_ops_pass import _ReplaceSymSizeOpPass
-from torch._decomp import core_aten_decompositions
+
+import torch.utils._pytree as pytree
+from torch._decomp import core_aten_decompositions, get_decompositions
+from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo.eval_frame import Constraint
+from torch._dynamo.exc import UserError, UserErrorType
 from torch._functorch.aot_autograd import aot_export_module
+from torch._functorch.eager_transforms import functionalize
 from torch._guards import detect_fake_mode
-from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
-
-import torch.utils._pytree as pytree
+from torch._subclasses.fake_tensor import FakeTensorMode
+from torch.fx import traceback as fx_traceback
+from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.experimental.symbolic_shapes import (
     ConstraintViolationError,
     GuardOnDataDependentSymNode,
     ShapeEnv,
     StrictMinMaxConstraint,
 )
+from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
+from torch.utils._sympy.value_ranges import ValueRangeError, ValueRanges

-from torch._dynamo.exc import UserError, UserErrorType
-from torch.utils._sympy.value_ranges import ValueRanges, ValueRangeError
-
+from .exported_program import (
+    _process_constraints,
+    CallSpec,
+    ExportBackwardSignature,
+    ExportedProgram,
+    ExportGraphSignature,
+)
+from .passes.replace_sym_size_ops_pass import _ReplaceSymSizeOpPass


 # Note - [On Export Dynamic Dimension UX]

@@ -156,6 +159,7 @@ def export(
             *args,
             constraints=constraints,
             assume_static_by_default=True,
+            tracing_mode="symbolic",
         )

     params_buffers: "OrderedDict[str, Union[torch.Tensor, torch.nn.Parameter]]" = OrderedDict()

@@ -190,18 +194,17 @@ def export(
                     params_buffers_to_node_meta[n.target] = meta

     fake_inps = []
-    for node in gm_torch_level.graph.nodes:
-        if node.op == "placeholder" and "val" in node.meta:
-            fake_val = node.meta["val"]
-            fake_inps.append(fake_val)
-
     fake_mode = FakeTensorMode(
         allow_fallback_kernels=False,
         allow_non_fake_inputs=True,
         shape_env=ShapeEnv(
             assume_static_by_default=True,
         ),
     )
+    for node in gm_torch_level.graph.nodes:
+        if node.op == "placeholder" and "val" in node.meta:
+            fake_val = node.meta["val"]
+            fake_inps.append(fake_val)

     if detected_fake_mode := detect_fake_mode(fake_inps):
         fake_mode = detected_fake_mode

@@ -228,6 +231,26 @@ def export(
     )

     gm_torch_level.recompile()
+
+    params_buffers_to_node_meta = OrderedDict()
+
+    for node in gm_torch_level.graph.nodes:
+        target = node.target
+        meta = node.meta
+        if node.op == "call_module":
+            submodule = getattr(gm_torch_level, target)
+            if isinstance(submodule, torch.nn.Module):
+                for name, _ in submodule.named_parameters(recurse=True, remove_duplicate=False):
+                    params_buffers_to_node_meta[target + "." + name] = meta
+
+                for name, _ in submodule.named_buffers(recurse=True, remove_duplicate=False):
+                    params_buffers_to_node_meta[target + "." + name] = meta
+
+        if node.op == "call_function" and not isinstance(node.target, torch._ops.HigherOrderOperator):
+            for n in node._input_nodes:
+                if n.op == "get_attr":
+                    params_buffers_to_node_meta[n.target] = meta
+
     gm, graph_signature = aot_export_module(gm_torch_level, fake_args, decompositions=DECOMP_TABLE, trace_joint=False)

     export_backward_signature = ExportBackwardSignature(
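The core of this diff is the block added after gm_torch_level.recompile(): it rebuilds params_buffers_to_node_meta on the recompiled torch-level graph so that parameter and buffer metadata can be re-attached after aot_export_module lifts them into graph placeholders. Below is a minimal standalone sketch of that bookkeeping, not part of the commit; the toy module and names are illustrative only.

# Standalone sketch (not from the commit): collect node metadata keyed by the
# fully qualified name of every parameter/buffer reachable from call_module nodes.
from collections import OrderedDict

import torch
import torch.fx


class Small(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


gm = torch.fx.symbolic_trace(Small())
params_buffers_to_node_meta = OrderedDict()

for node in gm.graph.nodes:
    if node.op == "call_module":
        submodule = getattr(gm, node.target)
        # every param/buffer owned by this submodule inherits the call site's meta
        for name, _ in submodule.named_parameters(recurse=True):
            params_buffers_to_node_meta[f"{node.target}.{name}"] = node.meta
        for name, _ in submodule.named_buffers(recurse=True):
            params_buffers_to_node_meta[f"{node.target}.{name}"] = node.meta

print(list(params_buffers_to_node_meta))  # e.g. ['linear.weight', 'linear.bias']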

‎torch/ao/quantization/pt2e/quantizer/utils.py

+7 −1
@@ -90,7 +90,13 @@ def _annotate_output_qspec(node: Node, qspec):


 def _is_sym_size_node(node: Node):
-    return node.op == "call_function" and node.target == torch.ops.aten.sym_size
+    return (
+        node.op == "call_function" and
+        node.target == torch.ops.aten.sym_size.default or
+        node.target == torch.ops.aten.sym_numel.default or
+        node.target == torch.ops.aten.sym_numel or
+        node.target == torch.ops.aten.sym_size
+    )


 def _node_only_used_for_sym_size(node: Node, partition_nodes: List[Node]):
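_is_sym_size_node now matches several aten targets rather than only the sym_size packet. A quick standalone check, not from the commit, of why both the overload packet and its .default overload appear in the list: FX call_function targets may carry either object, and they are distinct.

# Not from the commit: the packet and its concrete overload are different objects,
# so a predicate that only checks one of them can miss graphs that use the other.
import torch

packet = torch.ops.aten.sym_numel            # OpOverloadPacket
overload = torch.ops.aten.sym_numel.default  # OpOverload

print(type(packet).__name__, type(overload).__name__)  # OpOverloadPacket OpOverload
print(packet is overload)                               # False: distinct objects
print(overload.overloadpacket is packet)                # True: the overload points back at its packet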

‎torch/fx/graph.py

+2 −2
@@ -612,7 +612,7 @@ def process_inputs(self, *inputs: Any) -> Any:
         return flat_args

     def process_outputs(self, out: Any) -> Any:
-        if self.pytree_info is None:
+        if self.pytree_info is None or self.pytree_info.out_spec is None:
             return out
         if not isinstance(out, list):
             out = [out]

@@ -665,7 +665,7 @@ def gen_fn_def(self, free_vars, maybe_return_annotation):
         return fn_definition

     def generate_output(self, output_args):
-        if self.pytree_info:
+        if self.pytree_info and self.pytree_info.out_spec:
             return f'return pytree.tree_unflatten({repr(output_args)}, self._out_spec)'
         else:
             return super().generate_output(output_args)
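Both torch/fx/graph.py changes guard against a _PyTreeCodeGen whose pytree_info has no out_spec: with no output spec there is nothing to unflatten against, so the outputs are passed through unchanged. A small sketch, not from the commit, of what an out_spec lets the generated code do:

# Not from the commit: tree_flatten records a spec describing the container
# structure; tree_unflatten rebuilds that structure from the flat leaves.
# Without such a spec (out_spec is None), pass-through is the only safe behavior.
import torch.utils._pytree as pytree

structured = ([1, 2], (3,), 4)
flat, out_spec = pytree.tree_flatten(structured)

print(flat)                                   # [1, 2, 3, 4]
print(pytree.tree_unflatten(flat, out_spec))  # ([1, 2], (3,), 4)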
