1
1
import dataclasses
2
2
import inspect
3
- import weakref
4
3
import re
4
+ import weakref
5
5
from collections import OrderedDict
6
- from typing import Any , Callable , List , Tuple , Optional , Dict , Union
6
+ from typing import Any , Callable , Dict , List , Optional , Tuple , Union
7
7
8
8
import sympy
9
9
10
10
import torch
11
11
import torch ._dynamo
12
12
import torch .fx
13
- from torch .fx .graph import _PyTreeCodeGen , _PyTreeInfo
14
- from .exported_program import (
15
- CallSpec ,
16
- ExportedProgram ,
17
- ExportBackwardSignature ,
18
- ExportGraphSignature ,
19
- _process_constraints ,
20
- )
21
- from .passes .replace_sym_size_ops_pass import _ReplaceSymSizeOpPass
22
- from torch ._decomp import core_aten_decompositions
13
+
14
+ import torch .utils ._pytree as pytree
15
+ from torch ._decomp import core_aten_decompositions , get_decompositions
16
+ from torch ._dispatch .python import enable_python_dispatcher
23
17
from torch ._dynamo .eval_frame import Constraint
18
+ from torch ._dynamo .exc import UserError , UserErrorType
24
19
from torch ._functorch .aot_autograd import aot_export_module
20
+ from torch ._functorch .eager_transforms import functionalize
25
21
from torch ._guards import detect_fake_mode
26
- from torch ._subclasses .fake_tensor import FakeTensor , FakeTensorMode
27
-
28
- import torch .utils . _pytree as pytree
22
+ from torch ._subclasses .fake_tensor import FakeTensorMode
23
+ from torch . fx import traceback as fx_traceback
24
+ from torch .fx . experimental . proxy_tensor import make_fx
29
25
from torch .fx .experimental .symbolic_shapes import (
30
26
ConstraintViolationError ,
31
27
GuardOnDataDependentSymNode ,
32
28
ShapeEnv ,
33
29
StrictMinMaxConstraint ,
34
30
)
31
+ from torch .fx .graph import _PyTreeCodeGen , _PyTreeInfo
32
+ from torch .utils ._sympy .value_ranges import ValueRangeError , ValueRanges
35
33
36
- from torch ._dynamo .exc import UserError , UserErrorType
37
- from torch .utils ._sympy .value_ranges import ValueRanges , ValueRangeError
38
-
34
+ from .exported_program import (
35
+ _process_constraints ,
36
+ CallSpec ,
37
+ ExportBackwardSignature ,
38
+ ExportedProgram ,
39
+ ExportGraphSignature ,
40
+ )
41
+ from .passes .replace_sym_size_ops_pass import _ReplaceSymSizeOpPass
39
42
40
43
41
44
# Note - [On Export Dynamic Dimension UX]
@@ -156,6 +159,7 @@ def export(
156
159
* args ,
157
160
constraints = constraints ,
158
161
assume_static_by_default = True ,
162
+ tracing_mode = "symbolic" ,
159
163
)
160
164
161
165
params_buffers : "OrderedDict[str, Union[torch.Tensor, torch.nn.Parameter]]" = OrderedDict ()
@@ -190,18 +194,17 @@ def export(
190
194
params_buffers_to_node_meta [n .target ] = meta
191
195
192
196
fake_inps = []
193
- for node in gm_torch_level .graph .nodes :
194
- if node .op == "placeholder" and "val" in node .meta :
195
- fake_val = node .meta ["val" ]
196
- fake_inps .append (fake_val )
197
-
198
197
fake_mode = FakeTensorMode (
199
198
allow_fallback_kernels = False ,
200
199
allow_non_fake_inputs = True ,
201
200
shape_env = ShapeEnv (
202
201
assume_static_by_default = True ,
203
202
),
204
203
)
204
+ for node in gm_torch_level .graph .nodes :
205
+ if node .op == "placeholder" and "val" in node .meta :
206
+ fake_val = node .meta ["val" ]
207
+ fake_inps .append (fake_val )
205
208
206
209
if detected_fake_mode := detect_fake_mode (fake_inps ):
207
210
fake_mode = detected_fake_mode
@@ -228,6 +231,26 @@ def export(
228
231
)
229
232
230
233
gm_torch_level .recompile ()
234
+
235
+ params_buffers_to_node_meta = OrderedDict ()
236
+
237
+ for node in gm_torch_level .graph .nodes :
238
+ target = node .target
239
+ meta = node .meta
240
+ if node .op == "call_module" :
241
+ submodule = getattr (gm_torch_level , target )
242
+ if isinstance (submodule , torch .nn .Module ):
243
+ for name , _ in submodule .named_parameters (recurse = True , remove_duplicate = False ):
244
+ params_buffers_to_node_meta [target + "." + name ] = meta
245
+
246
+ for name , _ in submodule .named_buffers (recurse = True , remove_duplicate = False ):
247
+ params_buffers_to_node_meta [target + "." + name ] = meta
248
+
249
+ if node .op == "call_function" and not isinstance (node .target , torch ._ops .HigherOrderOperator ):
250
+ for n in node ._input_nodes :
251
+ if n .op == "get_attr" :
252
+ params_buffers_to_node_meta [n .target ] = meta
253
+
231
254
gm , graph_signature = aot_export_module (gm_torch_level , fake_args , decompositions = DECOMP_TABLE , trace_joint = False )
232
255
233
256
export_backward_signature = ExportBackwardSignature (
0 commit comments