Commit 89c1cfe

angelayi authored and pytorchmergebot committed
[export] Allow modules to be created in the forward (pytorch#125725)
Fixes the error in non-strict export when tracing a module that initializes another module in its forward function. This pattern appears in [many huggingface models](https://github.com/search?q=repo%3Ahuggingface%2Ftransformers+CrossEntropyLoss%28%29&type=code&fbclid=IwAR285uKvSevJM6SDbXmb4-monj4iH7wf8opkvnec-li7sKpn4lUMjIvbGKc). It is probably not good practice, but since it appears in so many places and strict export already supports it, we will support it as well.

The approach we take for these cases is to inline the call to the module. Parameters and buffers initialized as constants (with `torch.tensor`) will be represented as constant tensors, and those initialized with tensor factory functions (like `torch.ones`) will show up as operators in the graph. The module stack for the ops in the inlined module will reflect the toplevel's module stack.

One known issue is that strict export segfaults when there is an `nn.Parameter` call in the constructor (pytorch#126109); non-strict export succeeds.

Pull Request resolved: pytorch#125725
Approved by: https://github.com/ydwu4
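To make the supported pattern concrete, here is a minimal sketch modeled on the tests added in this PR; the `Inner`/`Outer` names are placeholders, and `strict=False` selects the non-strict path this change fixes:

import torch

class Inner(torch.nn.Module):
    def forward(self, x):
        return x + x

class Outer(torch.nn.Module):
    def forward(self, x):
        m = Inner()  # constructed in forward, never registered as a submodule
        return m(x) * x

inps = (torch.randn(3, 3),)
# Non-strict export now inlines the call to Inner instead of raising.
ep = torch.export.export(Outer(), inps, strict=False)
assert torch.allclose(ep.module()(*inps), Outer()(*inps))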
1 parent 6550386 commit 89c1cfe

File tree

2 files changed (+174, -2)

test/export/test_export.py (+154)
@@ -148,6 +148,7 @@ class Inp:

 NON_STRICT_SUFFIX = "_non_strict"
 RETRACEABILITY_SUFFIX = "_retraceability"
+PREDISPATCH_SUFFIX = "_pre_dispatch"


 def is_non_strict_test(test_name):
@@ -3279,6 +3280,159 @@ def dynamify_inp(x):
         with self.assertRaisesRegex(RuntimeError, "shape\[0\] to be >= 3, but got 2"):
             ep.module()(*test_inp)

+    def test_nested_module(self):
+        class M1(torch.nn.Module):
+            def forward(self, x):
+                return x + x
+
+        class M2(torch.nn.Module):
+            def forward(self, x):
+                m = M1()
+                return m(x) * x
+
+        inps = (torch.randn(3, 3),)
+        ep = export(M2(), inps)
+        self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps)))
+
+        add_nodes = [
+            node
+            for node in ep.graph.nodes
+            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor
+        ]
+        self.assertEqual(len(add_nodes), 1)
+        add_node = add_nodes[0]
+        self.assertEqual(len(add_node.meta["nn_module_stack"]), 1)
+        self.assertTrue("M2" in list(add_node.meta["nn_module_stack"].values())[0][1])
+
+        self.assertExpectedInline(
+            str(ep.graph).strip(),
+            """\
+graph():
+    %x : [num_users=2] = placeholder[target=x]
+    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %x), kwargs = {})
+    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {})
+    return (mul,)""",
+        )
+
+        unflattened = unflatten(ep)
+        self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps)))
+
+    def test_nested_module_with_init_buffer(self):
+        class M1(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.b = torch.ones(3, 3)
+
+            def forward(self, x):
+                return x + self.b
+
+        class M2(torch.nn.Module):
+            def forward(self, x):
+                m = M1()
+                return m(x) * x
+
+        inps = (torch.randn(3, 3),)
+        ep = export(M2(), inps)
+        self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps)))
+
+        self.assertEqual(len(ep.state_dict), 0)
+        self.assertEqual(len(ep.constants), 0)
+
+        self.assertExpectedInline(
+            str(ep.graph).strip(),
+            """\
+graph():
+    %x : [num_users=2] = placeholder[target=x]
+    %ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([3, 3],), kwargs = {device: cpu, pin_memory: False})
+    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %ones), kwargs = {})
+    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {})
+    return (mul,)""",
+        )
+
+        unflattened = unflatten(ep)
+        self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps)))
+
+    @testing.expectedFailureRetraceability  # Retracing tensor constants results in buffers
+    def test_nested_module_with_constant_buffer(self):
+        class M1(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.b = torch.tensor(5)
+
+            def forward(self, x):
+                return x + self.b
+
+        class M2(torch.nn.Module):
+            def forward(self, x):
+                m = M1()
+                return m(x) * x
+
+        inps = (torch.randn(3, 3),)
+        ep = export(M2(), inps)
+        self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps)))
+
+        self.assertEqual(len(ep.state_dict), 0)
+        self.assertEqual(len(ep.constants), 1)
+
+        self.assertExpectedInline(
+            str(ep.graph).strip(),
+            """\
+graph():
+    %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0]
+    %x : [num_users=2] = placeholder[target=x]
+    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {})
+    %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%lift_fresh_copy,), kwargs = {})
+    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %detach), kwargs = {})
+    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {})
+    return (mul,)""",
+        )
+
+        unflattened = unflatten(ep)
+        self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps)))
+
+    def test_nested_module_with_parameter(self):
+        class M1(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.a = torch.nn.Parameter(torch.ones(3, 3))
+                self.b = torch.nn.Parameter(torch.tensor(5.0))
+
+            def forward(self, x):
+                return x + self.a * self.b
+
+        class M2(torch.nn.Module):
+            def forward(self, x):
+                m = M1()
+                return m(x) * x
+
+        inps = (torch.randn(3, 3),)
+        # Strict export segfaults (Issue #126109)
+        ep = torch.export.export(M2(), inps, strict=False)
+        self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps)))
+
+        self.assertEqual(len(ep.state_dict), 0)
+        self.assertEqual(len(ep.constants), 1)
+
+        self.assertExpectedInline(
+            str(ep.graph).strip(),
+            """\
+graph():
+    %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0]
+    %x : [num_users=2] = placeholder[target=x]
+    %ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([3, 3],), kwargs = {device: cpu, pin_memory: False})
+    %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%ones,), kwargs = {})
+    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {})
+    %detach_1 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%lift_fresh_copy,), kwargs = {})
+    %detach_2 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach_1,), kwargs = {})
+    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%detach, %detach_2), kwargs = {})
+    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %mul), kwargs = {})
+    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {})
+    return (mul_1,)""",
+        )
+
+        unflattened = unflatten(ep)
+        self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps)))
+
     def test_lazy_module_kwargs(self):
         class LazyModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
             def initialize_parameters(self, *args, **kwargs):
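The two buffer tests above exercise the distinction described in the commit message: a `torch.tensor` constant is lifted into `ep.constants`, while a `torch.ones` factory call is traced as an `aten.ones` operator, and neither ends up in `ep.state_dict`. A rough standalone sketch of checking this (module names here are made up; the expected counts come from the tests above):

import torch
from torch.export import export

class WithConst(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.b = torch.tensor(5)  # constant -> lifted into ep.constants

    def forward(self, x):
        return x + self.b

class Wrapper(torch.nn.Module):
    def forward(self, x):
        return WithConst()(x) * x  # inner module call is inlined

ep = export(Wrapper(), (torch.randn(3, 3),), strict=False)
print(len(ep.state_dict))  # 0: nothing was registered as a parameter or buffer
print(len(ep.constants))   # 1: the lifted torch.tensor(5) constant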

torch/fx/experimental/proxy_tensor.py (+20, -2)
@@ -29,6 +29,7 @@
 from torch.utils._traceback import CapturedTraceback
 import logging
 from torch._library.fake_class_registry import FakeScriptObject
+import warnings

 from torch.overrides import TorchFunctionMode

@@ -921,6 +922,10 @@ def disable_autocast_cache():
         torch.set_autocast_cache_enabled(old_value)


+class ModuleNotInstalledAsSubmoduleError(NameError):
+    pass
+
+
 class _ModuleStackTracer(PythonKeyTracer):
     r"""Customized version of PythonKeyTracer that retains module stack
     information in node.meta["nn_module_stack"].
@@ -998,7 +1003,10 @@ def path_of_module(self, mod: torch.nn.Module) -> str:
         if isinstance(mod, self.proxy_type):
             return self.proxy_paths[mod]

-        return Tracer.path_of_module(self, mod)
+        try:
+            return Tracer.path_of_module(self, mod)
+        except NameError as e:
+            raise ModuleNotInstalledAsSubmoduleError from e

     def getattr(self, attr, attr_val, parameter_proxy_cache):
         if not isinstance(attr_val, torch.nn.Module) or isinstance(attr_val, torch.fx.GraphModule):
@@ -1070,7 +1078,17 @@ def call_module(self, m, forward, args, kwargs):
         # use cases don't need to work with HOO.
         if isinstance(m, (OptimizedModule, GraphModule)):
             return forward(*args, **kwargs)
-        return Tracer.call_module(self, m, forward, args, kwargs)
+
+        try:
+            return Tracer.call_module(self, m, forward, args, kwargs)
+        except ModuleNotInstalledAsSubmoduleError as e:
+            warnings.warn(
+                f"Unable to find the path of the module {m}. "
+                "This might be because the module was not properly registered "
+                "as a submodule, which is not good practice. We will trace "
+                "through the module without recording stack information."
+            )
+            return forward(*args, **kwargs)


     def is_leaf_module(self, m, module_qualified_name):
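To observe the new fallback end to end, one could, as a hedged sketch, record the warning while non-strict-exporting a module that builds its submodule in forward (the classes are placeholders; the warning text is the one added above):

import warnings
import torch

class Inner(torch.nn.Module):
    def forward(self, x):
        return x + x

class Outer(torch.nn.Module):
    def forward(self, x):
        return Inner()(x) * x  # Inner is never installed as a submodule

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ep = torch.export.export(Outer(), (torch.randn(3, 3),), strict=False)

# If the tracer hit the fallback path, one message should start with
# "Unable to find the path of the module ...".
print([str(w.message) for w in caught])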
