@@ -148,6 +148,7 @@ class Inp:
 
 NON_STRICT_SUFFIX = "_non_strict"
 RETRACEABILITY_SUFFIX = "_retraceability"
+PREDISPATCH_SUFFIX = "_pre_dispatch"
 
 
 def is_non_strict_test(test_name):
@@ -3279,6 +3280,159 @@ def dynamify_inp(x):
         with self.assertRaisesRegex(RuntimeError, "shape\[0\] to be >= 3, but got 2"):
             ep.module()(*test_inp)
 
+    def test_nested_module(self):
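+        # A module instantiated inside forward() should be inlined into the
+        # exported graph; its ops are attributed to the outer module (M2) in
+        # nn_module_stack rather than to a named submodule.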
+        class M1(torch.nn.Module):
+            def forward(self, x):
+                return x + x
+
+        class M2(torch.nn.Module):
+            def forward(self, x):
+                m = M1()
+                return m(x) * x
+
+        inps = (torch.randn(3, 3),)
+        ep = export(M2(), inps)
+        self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps)))
+
+        add_nodes = [
+            node
+            for node in ep.graph.nodes
+            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor
+        ]
+        self.assertEqual(len(add_nodes), 1)
+        add_node = add_nodes[0]
+        self.assertEqual(len(add_node.meta["nn_module_stack"]), 1)
+        self.assertTrue("M2" in list(add_node.meta["nn_module_stack"].values())[0][1])
+
+        self.assertExpectedInline(
+            str(ep.graph).strip(),
+            """\
+graph():
+    %x : [num_users=2] = placeholder[target=x]
+    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %x), kwargs = {})
+    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {})
+    return (mul,)""",
+        )
+
+        unflattened = unflatten(ep)
+        self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps)))
+
+    def test_nested_module_with_init_buffer(self):
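+        # A plain tensor attribute created in __init__ of the inline module is
+        # retraced as a fresh aten.ones call, so nothing is lifted into the
+        # state dict or the constant table.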
+        class M1(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.b = torch.ones(3, 3)
+
+            def forward(self, x):
+                return x + self.b
+
+        class M2(torch.nn.Module):
+            def forward(self, x):
+                m = M1()
+                return m(x) * x
+
+        inps = (torch.randn(3, 3),)
+        ep = export(M2(), inps)
+        self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps)))
+
+        self.assertEqual(len(ep.state_dict), 0)
+        self.assertEqual(len(ep.constants), 0)
+
+        self.assertExpectedInline(
+            str(ep.graph).strip(),
+            """\
+graph():
+    %x : [num_users=2] = placeholder[target=x]
+    %ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([3, 3],), kwargs = {device: cpu, pin_memory: False})
+    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %ones), kwargs = {})
+    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {})
+    return (mul,)""",
+        )
+
+        unflattened = unflatten(ep)
+        self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps)))
+
+    @testing.expectedFailureRetraceability  # Retracing tensor constants results in buffers
+    def test_nested_module_with_constant_buffer(self):
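+        # A tensor constant on the inline module is lifted as a constant
+        # graph input (c_lifted_tensor_0) instead of becoming a buffer in the
+        # state dict.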
+        class M1(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.b = torch.tensor(5)
+
+            def forward(self, x):
+                return x + self.b
+
+        class M2(torch.nn.Module):
+            def forward(self, x):
+                m = M1()
+                return m(x) * x
+
+        inps = (torch.randn(3, 3),)
+        ep = export(M2(), inps)
+        self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps)))
+
+        self.assertEqual(len(ep.state_dict), 0)
+        self.assertEqual(len(ep.constants), 1)
+
+        self.assertExpectedInline(
+            str(ep.graph).strip(),
+            """\
+graph():
+    %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0]
+    %x : [num_users=2] = placeholder[target=x]
+    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {})
+    %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%lift_fresh_copy,), kwargs = {})
+    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %detach), kwargs = {})
+    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {})
+    return (mul,)""",
+        )
+
+        unflattened = unflatten(ep)
+        self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps)))
+
+    def test_nested_module_with_parameter(self):
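+        # Parameters of the inline module do not become parameters of the
+        # exported program: the ones(3, 3) init is retraced and only the
+        # scalar tensor is lifted as a constant.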
+        class M1(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.a = torch.nn.Parameter(torch.ones(3, 3))
+                self.b = torch.nn.Parameter(torch.tensor(5.0))
+
+            def forward(self, x):
+                return x + self.a * self.b
+
+        class M2(torch.nn.Module):
+            def forward(self, x):
+                m = M1()
+                return m(x) * x
+
+        inps = (torch.randn(3, 3),)
+        # Strict export segfaults (Issue #128109)
+        ep = torch.export.export(M2(), inps, strict=False)
+        self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps)))
+
+        self.assertEqual(len(ep.state_dict), 0)
+        self.assertEqual(len(ep.constants), 1)
+
+        self.assertExpectedInline(
+            str(ep.graph).strip(),
+            """\
+graph():
+    %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0]
+    %x : [num_users=2] = placeholder[target=x]
+    %ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([3, 3],), kwargs = {device: cpu, pin_memory: False})
+    %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%ones,), kwargs = {})
+    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {})
+    %detach_1 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%lift_fresh_copy,), kwargs = {})
+    %detach_2 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach_1,), kwargs = {})
+    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%detach, %detach_2), kwargs = {})
+    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %mul), kwargs = {})
+    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {})
+    return (mul_1,)""",
+        )
+
+        unflattened = unflatten(ep)
+        self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps)))
+
     def test_lazy_module_kwargs(self):
         class LazyModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
             def initialize_parameters(self, *args, **kwargs):