
Commit 48b55ca

yiming0416 authored and pytorchmergebot committed
[export] Fix non-strict retracing with kwargs (pytorch#138927)
Summary: `torch.fx.Interpreter.run()` only takes positional args as input. Currently we pass kwargs as well, which causes errors during retracing. Flattening the kwargs and concatenating them with the args solves the issue. Several previously failing tests under `_retraceability_non_strict` now pass.

Test Plan:
```
buck2 test @//mode/dev-nosan //caffe2/test:test_export -- -r _retraceability_non_strict
```

Differential Revision: D64980053

Pull Request resolved: pytorch#138927

Approved by: https://github.com/angelayi
1 parent 3342b53 commit 48b55ca
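For illustration only, here is a minimal sketch (not the patch itself) of the behavior the fix relies on: `torch.fx.Interpreter.run()` takes only positional inputs and binds them to the graph's placeholders in order, so keyword arguments have to be appended after the positional ones. The toy module `AddScaled` and the variable names below are assumptions made for this example.

```python
import torch

class AddScaled(torch.nn.Module):
    def forward(self, x, scale):
        return x * scale

# Symbolic tracing produces one placeholder per forward() parameter,
# in declaration order: (x, scale).
gm = torch.fx.symbolic_trace(AddScaled())

args = (torch.randn(3),)
kwargs = {"scale": torch.tensor(2.0)}

# Interpreter.run() has no keyword-argument path, so fold the kwargs
# into the positional inputs before running the graph.
flat_args = (*args, *kwargs.values())
out = torch.fx.Interpreter(gm).run(*flat_args)

assert torch.allclose(out, AddScaled()(*args, **kwargs))
```

Calling `run(*args, **kwargs)` instead fails with a `TypeError` about unexpected keyword arguments, which appears to be the kind of error the retraceability tests were hitting.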

File tree

2 files changed: +2 -9 lines


test/export/test_export.py (-8 lines)

@@ -2431,7 +2431,6 @@ def forward(self, x, y, z):
            if node.op == "placeholder":
                self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)")

-    @testing.expectedFailureRetraceabilityNonStrict
    def test_dynamic_shapes_builder_kwargs(self):
        class M(torch.nn.Module):
            def forward(self, x, y, z):
@@ -3129,7 +3128,6 @@ def forward(self, image, crop_height, crop_width):
        args = (torch.rand(3, 700, 700), 150, 150)
        self.assertEqual(ecrop.module()(*args), ecrop(*args))

-    @testing.expectedFailureRetraceabilityNonStrict
    def test_export_func_with_kwargs(self):
        class Module(torch.nn.Module):
            def forward(self, arg1, arg2, kw1, kw2):
@@ -3140,7 +3138,6 @@ def forward(self, arg1, arg2, kw1, kw2):
        kwargs = {"kw1": torch.ones(1, 1), "kw2": torch.ones(6, 4)}
        self._test_export_same_as_eager(kw_func, args, kwargs)

-    @testing.expectedFailureRetraceabilityNonStrict
    def test_export_func_with_pytree_kwargs(self):
        class Module(torch.nn.Module):
            def forward(self, arg1, arg2, a, b):
@@ -3154,7 +3151,6 @@ def forward(self, arg1, arg2, a, b):
        }
        self._test_export_same_as_eager(kw_func, args, kwargs)

-    @testing.expectedFailureRetraceabilityNonStrict
    def test_export_func_with_default_kwargs(self):
        class Module(torch.nn.Module):
            def forward(self, arg1, arg2, a, b=1):
@@ -3185,7 +3181,6 @@ def forward(self, arg1, arg2, *args):
        args = (torch.ones(2, 3), torch.ones(3, 4), torch.ones(2, 3), torch.ones(3, 4))
        self._test_export_same_as_eager(kw_func, args)

-    @testing.expectedFailureRetraceabilityNonStrict
    def test_export_func_with_keyword_only_args(self):
        class Module(torch.nn.Module):
            def forward(self, arg1, arg2, *args, kw1, kw2):
@@ -3196,7 +3191,6 @@ def forward(self, arg1, arg2, *args, kw1, kw2):
        kwargs = {"kw1": torch.ones(2, 3), "kw2": torch.ones(3, 4)}
        self._test_export_same_as_eager(kw_func, args, kwargs)

-    @testing.expectedFailureRetraceabilityNonStrict
    def test_export_func_with_var_keyword_args(self):
        class Module(torch.nn.Module):
            def forward(self, arg1, arg2, *args, kw1, kw2, **kwargs):
@@ -3291,7 +3285,6 @@ def forward(self, x, y):
        self.assertTrue(torch.allclose(orig_res[1], ep_res[1]))
        self.assertTrue(torch.allclose(orig_res[2], ep_res[2]))

-    @testing.expectedFailureRetraceabilityNonStrict
    def test_export_func_with_var_keyword_pytree_args(self):
        class Module(torch.nn.Module):
            def forward(self, arg1, arg2, *args, kw1, kw2, **kwargs):
@@ -5656,7 +5649,6 @@ def forward(self, x):
        unflattened = unflatten(ep)
        self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps)))

-    @testing.expectedFailureRetraceabilityNonStrict
    def test_lazy_module_kwargs(self):
        class LazyModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
            def initialize_parameters(self, *args, **kwargs):
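For context on the decorators removed above, here is a hedged sketch of the retracing scenario that the `_retraceability_non_strict` test variants exercise: export a module called with kwargs in non-strict mode, then export the resulting `ep.module()` again with the same inputs. The module and inputs are illustrative, not taken from the test suite.

```python
import torch
from torch.export import export

class KwModule(torch.nn.Module):
    def forward(self, x, scale):
        return x * scale

args = (torch.randn(2, 2),)
kwargs = {"scale": torch.tensor(3.0)}

# Non-strict export, then retrace the exported module with the same
# args/kwargs; the retrace runs the graph through torch.fx.Interpreter,
# which is the code path touched in torch/export/_trace.py below.
ep = export(KwModule(), args, kwargs, strict=False)
retraced = export(ep.module(), args, kwargs, strict=False)

assert torch.allclose(
    retraced.module()(*args, **kwargs), KwModule()(*args, **kwargs)
)
```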

torch/export/_trace.py (+2 -1 lines)

@@ -1641,7 +1641,8 @@ def forward(self, *args, **kwargs):
        ):
            _check_input_constraints_pre_hook(mod, args, kwargs)
            with torch.fx.traceback.preserve_node_meta():
-                tree_out = torch.fx.Interpreter(mod).run(*args, **kwargs)
+                args = (*args, *kwargs.values())
+                tree_out = torch.fx.Interpreter(mod).run(*args)
        else:
            tree_out = mod(*args, **kwargs)
        flat_outs, out_spec = pytree.tree_flatten(tree_out)
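The one-line change assumes that the GraphModule's placeholders correspond to the positional args followed by the kwargs in their flattened order, so `(*args, *kwargs.values())` lines up with the graph's inputs. As an illustrative sanity check (an assumption of this example, not part of the patch), the placeholder order of an exported module can be inspected directly:

```python
import torch
from torch.export import export

class KwModule(torch.nn.Module):
    def forward(self, x, scale):
        return x * scale

ep = export(
    KwModule(), (torch.randn(2),), {"scale": torch.tensor(3.0)}, strict=False
)
gm = ep.module()

# Placeholder nodes appear in the order Interpreter.run() binds its
# positional inputs; the kwarg "scale" comes after the positional "x".
print([n.target for n in gm.graph.nodes if n.op == "placeholder"])
```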
