@@ -3496,6 +3496,7 @@ def _get_example_tensor_lists(
         include_list_of_lists=True,
         include_requires_grad=True,
         include_inner_dim_size_1=False,
+        include_2d_tensor=False,
     ):
         def _make_tensor(
             *shape, include_requires_grad=include_requires_grad, requires_grad=True
@@ -3562,6 +3563,16 @@ def _make_tensor(
                 ]  # (B, *, 5, 1)
             )

+        if include_2d_tensor:
+            example_lists.append(
+                [
+                    _make_tensor(2),
+                    _make_tensor(3, requires_grad=False),
+                    _make_tensor(4, requires_grad=False),
+                    _make_tensor(6),
+                ]  # (B, *)
+            )
+
         return example_lists

     def test_tensor_attributes(self, device):
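Note: a minimal sketch of how the new include_2d_tensor flag is meant to be exercised from a test in the same class; the surrounding call is illustrative, not part of the diff:

    tensor_lists = self._get_example_tensor_lists(
        include_list_of_lists=False,
        include_requires_grad=True,
        include_inner_dim_size_1=True,  # adds a (B, *, 1) example list
        include_2d_tensor=True,  # adds the new (B, *) list of 1D components
    )
    for tensor_list in tensor_lists:
        nt = torch.nested.nested_tensor(tensor_list, layout=torch.jagged)

With the flag set, one of the returned lists contains only 1D tensors, so the resulting jagged nested tensor is 2D; the layer_norm tests added below use this to reach the "2 or fewer dimensions" error path.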
@@ -4137,7 +4148,7 @@ def test_jagged_op_different_output_shape_dim(
     @dtypes(torch.float32)
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
-    def test_jagged_softmax_dim(
+    def test_softmax_dim(
         self,
         device,
         dtype,
@@ -4210,7 +4221,7 @@ def test_jagged_softmax_dim(
     @parametrize("keepdim", [False, True])
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
-    def test_jagged_op_dim_reduce_ragged_idx_1_different_output_shape(
+    def test_op_dim_reduce_ragged_idx_1_different_output_shape(
         self, device, dtype, keepdim, requires_grad, components_require_grad, func
     ):
         """
@@ -4252,7 +4263,7 @@ def test_jagged_op_dim_reduce_ragged_idx_1_different_output_shape(
     @dtypes(torch.float32)
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
-    def test_jagged_softmax_dim_reduce_ragged_idx_1_same_output_shape(
+    def test_softmax_dim_reduce_ragged_idx_1(
         self, device, dtype, requires_grad, components_require_grad
     ):
         """
@@ -4319,6 +4330,133 @@ def test_softmax_reduce_batch_dim(
         ):
             out = torch.nn.functional.softmax(nt, dim=reduce_dim)

+    @dtypes(torch.float32)
+    @parametrize("requires_grad", [False, True])
+    @parametrize("components_require_grad", [False, True])
+    def test_layer_norm_reduce_ragged_idx_1(
+        self, device, dtype, requires_grad, components_require_grad
+    ):
+        """
+        Layer normalization on NestedTensor passes when trying to normalize across ragged dimension, where ragged_idx == 1.
+        """
+
+        # requires_grad = False does not currently work with dynamo tests and throws this error:
+        # AssertionError: SymInts must use SymNodeVariable.
+        # If the underlying value is static, we will create a ConstantVariable and specialize.
+        if torch._dynamo.is_compiling() and not requires_grad:
+            return
+
+        tensor_lists = self._get_example_tensor_lists(
+            include_list_of_lists=False,
+            include_requires_grad=components_require_grad,
+            include_inner_dim_size_1=True,  # (B, *, 1)
+        )
+
+        for tensor_list in tensor_lists:
+            nt = torch.nested.nested_tensor(
+                tensor_list,
+                device=device,
+                dtype=dtype,
+                layout=torch.jagged,
+                requires_grad=requires_grad,
+            )
+
+            if (
+                nt.dim() >= 3
+            ):  # layer norm only works for tensors with 3 or more dimensions
+                normalized_shape = nt.shape[nt._ragged_idx :]
+
+                out_actual = torch.nn.functional.layer_norm(
+                    nt, normalized_shape=normalized_shape
+                )
+                out_expected = torch.cat(
+                    [
+                        torch.nn.functional.layer_norm(t, normalized_shape=t.shape)
+                        for t in nt.unbind()
+                    ]
+                )  # e.g. in 3D tensor (B, *, M), performs layer normalization on B 2D tensors (*, M)
+
+                self.assertTrue(
+                    out_actual.is_nested,
+                    "layer_norm(): the result of reducing a nested tensor along the ragged dimension is a nested tensor",
+                )  # output is a nested tensor
+                self.assertEqual(out_actual._values.shape, out_expected.shape)
+                self.assertTrue(torch.allclose(out_actual.values(), out_expected))
+
+    @dtypes(torch.float32)
+    @parametrize("requires_grad", [False, True])
+    @parametrize("components_require_grad", [False, True])
+    def test_layer_norm_2d_input(
+        self,
+        device,
+        dtype,
+        requires_grad,
+        components_require_grad,
+    ):
+        """
+        Layer normalization on NestedTensor fails when trying to operate on a 2-dimensional tensor
+        """
+        tensor_lists = self._get_example_tensor_lists(
+            include_list_of_lists=False,
+            include_requires_grad=components_require_grad,
+            include_inner_dim_size_1=True,  # (B, *, 1)
+            include_2d_tensor=True,  # (B, *)
+        )
+
+        for tensor_list in tensor_lists:
+            nt = torch.nested.nested_tensor(
+                tensor_list,
+                device=device,
+                dtype=dtype,
+                layout=torch.jagged,
+                requires_grad=requires_grad,
+            )
+
+            if nt.dim() <= 2:
+                with self.assertRaisesRegex(
+                    RuntimeError,
+                    "not supported for NestedTensor objects with 2 or fewer dimensions",
+                ):
+                    out = torch.nn.functional.layer_norm(
+                        nt, normalized_shape=(nt.shape[nt._ragged_idx],)
+                    )
+
+    @dtypes(torch.float32)
+    @parametrize("requires_grad", [False, True])
+    @parametrize("components_require_grad", [False, True])
+    def test_layer_norm_operate_on_batch_dim(
+        self,
+        device,
+        dtype,
+        requires_grad,
+        components_require_grad,
+    ):
+        """
+        Layer normalization on NestedTensor fails when trying to operate on the batch dimension
+        """
+        tensor_lists = self._get_example_tensor_lists(
+            include_list_of_lists=False,
+            include_requires_grad=components_require_grad,
+            include_inner_dim_size_1=True,  # (B, *, 1)
+            include_2d_tensor=True,  # (B, *)
+        )
+
+        for tensor_list in tensor_lists:
+            nt = torch.nested.nested_tensor(
+                tensor_list,
+                device=device,
+                dtype=dtype,
+                layout=torch.jagged,
+                requires_grad=requires_grad,
+            )
+
+            if nt.dim() > 2:  # cannot perform layer normalization on 2D tensors
+                with self.assertRaisesRegex(
+                    RuntimeError,
+                    "not supported when normalizing over the batch dimension for NestedTensor",
+                ):
+                    out = torch.nn.functional.layer_norm(nt, normalized_shape=nt.shape)
+
     @dtypes(torch.float32)
     @parametrize(
         "func",
@@ -4331,7 +4469,7 @@ def test_softmax_reduce_batch_dim(
     @parametrize("keepdim", [False, True])
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
-    def test_jagged_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape(
+    def test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape(
         self,
         device,
         dtype,
@@ -4391,7 +4529,7 @@ def test_jagged_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape(
     )  # [transpose consecutive dimensions, transpose nonconsecutive dimensions]
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
-    def test_jagged_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape(
+    def test_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape(
         self,
         device,
         dtype,
@@ -4439,7 +4577,7 @@ def test_jagged_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape(
     @parametrize("keepdim", [False, True])
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
-    def test_jagged_op_dim_transpose_non_ragged_dim_different_output_shape(
+    def test_op_dim_transpose_non_ragged_dim_different_output_shape(
         self, device, dtype, keepdim, requires_grad, components_require_grad, func
     ):
         """
@@ -4508,7 +4646,7 @@ def test_jagged_op_dim_transpose_non_ragged_dim_different_output_shape(
     @dtypes(torch.float32)
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
-    def test_jagged_softmax_dim_transpose_non_ragged_dim(
+    def test_softmax_dim_transpose_non_ragged_dim(
         self,
         device,
         dtype,
@@ -4560,7 +4698,7 @@ def test_jagged_softmax_dim_transpose_non_ragged_dim(
     @parametrize("keepdim", [False, True])
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
-    def test_jagged_sum_dim_reduce_ragged_and_non_batch(
+    def test_sum_dim_reduce_ragged_and_non_batch(
         self,
         device,
         dtype,
@@ -4599,7 +4737,7 @@ def test_jagged_sum_dim_reduce_ragged_and_non_batch(
     @parametrize("keepdim", [False, True])
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
-    def test_jagged_sum_dim_reduce_batch_and_non_batch(
+    def test_sum_dim_reduce_batch_and_non_batch(
         self,
         device,
         dtype,
@@ -4643,7 +4781,7 @@ def test_jagged_sum_dim_reduce_batch_and_non_batch(
     @parametrize("keepdim", [False, True])
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
-    def test_jagged_op_dim_reduce_batch_only_different_output_shape(
+    def test_op_dim_reduce_batch_only_different_output_shape(
         self, device, dtype, keepdim, requires_grad, components_require_grad, func
     ):
         """
@@ -4681,7 +4819,7 @@ def test_jagged_op_dim_reduce_batch_only_different_output_shape(
     @parametrize("keepdim", [False, True])
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
-    def test_jagged_op_dim_with_lengths_different_output_shape(
+    def test_op_dim_with_lengths_different_output_shape(
         self,
         device,
         dtype,
@@ -4736,7 +4874,7 @@ def test_jagged_op_dim_with_lengths_different_output_shape(
     @dtypes(torch.float32)
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
-    def test_jagged_softmax_dim_with_lengths(
+    def test_softmax_dim_with_lengths(
         self,
         device,
         dtype,
@@ -4782,11 +4920,69 @@ def test_jagged_softmax_dim_with_lengths(
         else:
             out = torch.nn.functional.softmax(nt_with_holes, dim=reduce_dim)

+    @skipIfTorchDynamo(
+        "ragged_size = nt_with_holes.shape[nt_with_holes._ragged_idx] does not currently work "
+        + "with dynamo tests and throws this error: `AssertionError: SymInts must use SymNodeVariable. "
+        + "If the underlying value is static, we will create a ConstantVariable and specialize.`"
+    )
+    @dtypes(torch.float32)
+    @parametrize("requires_grad", [False, True])
+    @parametrize("components_require_grad", [False, True])
+    def test_layer_norm_with_lengths(
+        self,
+        device,
+        dtype,
+        requires_grad,
+        components_require_grad,
+    ):
+        """
+        Layer normalization on NestedTensor fails when trying to operate on a nested tensor with lengths,
+        i.e. a nested tensor with holes, if operating on the ragged dimension.
+        """
+
+        # create components for nested tensor
+        lengths = torch.randint(5, 10, (20,), device=device)
+        offsets = torch.zeros((21,), device=device, dtype=torch.int)
+        torch.cumsum(lengths, dim=0, out=offsets[1:])
+        values = torch.randn(
+            (offsets[-1].item(), 10, 30),
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+        )
+
+        nt_with_holes = torch.nested.nested_tensor_from_jagged(
+            values,
+            offsets,
+            lengths=offsets.diff() - 2,  # arbitrary subtraction to create holes
+        )
+
+        ragged_size = nt_with_holes.shape[nt_with_holes._ragged_idx]
+
+        normalized_shapes = (
+            (10, 30),  # normalization on non-ragged dimension passes
+            (ragged_size, 10, 30),  # normalization on ragged dimension fails
+        )
+
+        for normalized_shape in normalized_shapes:
+            if ragged_size in normalized_shape:
+                with self.assertRaisesRegex(
+                    RuntimeError,
+                    "not supported where lengths is not None if operating on the ragged dimension for NestedTensor",
+                ):
+                    out = torch.nn.functional.layer_norm(
+                        nt_with_holes, normalized_shape=normalized_shape
+                    )
+            else:
+                out = torch.nn.functional.layer_norm(
+                    nt_with_holes, normalized_shape=normalized_shape
+                )
+
     @dtypes(torch.float32)
     @parametrize("keepdim", [True])
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
-    def test_jagged_mean_dim_reduce_multiple_dims(
+    def test_mean_dim_reduce_multiple_dims(
         self,
         device,
         dtype,
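Note: a brief sketch of how the "holes" in test_layer_norm_with_lengths arise; the small sizes are illustrative, while the API calls mirror the diff:

    values = torch.randn(12, 10, 30)
    offsets = torch.tensor([0, 5, 12])
    # lengths shorter than offsets.diff() leave the trailing rows of each component
    # unused, i.e. the nested tensor has holes.
    nt_with_holes = torch.nested.nested_tensor_from_jagged(
        values, offsets, lengths=offsets.diff() - 2
    )

Normalizing such a tensor over only its trailing non-ragged dims (10, 30) is expected to pass, while including the ragged size in normalized_shape raises the "lengths is not None" error quoted above.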
@@ -4826,7 +5022,7 @@ def test_jagged_mean_dim_reduce_multiple_dims(
     @parametrize("keepdim", [False, True])
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
-    def test_jagged_mean_dim_keepdim_False(
+    def test_mean_dim_keepdim_False(
         self,
         device,
         dtype,
@@ -5548,29 +5744,6 @@ def test_unbind_lengths_ragged_idx_0(self, device):
             lambda: nt.unbind(),
         )

-    @xfailIfTorchDynamo
-    def test_layer_norm_2(self, device):
-        test_tensor_list = self._get_list_for_jagged_tensor(
-            ((2, 3, 4), 3), device=device, requires_grad=True
-        )
-        bias = torch.randn(3, requires_grad=False, dtype=torch.float64, device=device)
-
-        def grad_test_func(a, b, c, bias):
-            nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
-            out = torch.nn.functional.layer_norm(nt, (nt.shape[-1],), bias=bias)
-            return out.values()
-
-        gradcheck(
-            grad_test_func, inputs=(*test_tensor_list, bias), check_batched_grad=False
-        )
-
-        with self.assertRaisesRegex(
-            RuntimeError,
-            r"layer_norm\(\): normalizing over ragged dim not supported for nested tensors",
-        ):
-            nt = torch.nested.as_nested_tensor(test_tensor_list, layout=torch.jagged)
-            _ = torch.nn.functional.layer_norm(nt, (nt.shape[-2], nt.shape[-1]))
-
     def test_narrow(self, device):
         starts = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64)
         lengths = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64)
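Note: the removed test_layer_norm_2 asserted that normalizing over the ragged dimension raises; the new layer_norm tests above assert the opposite, so its error-path regex would no longer match. A hedged sketch of the changed expectation (test_tensor_list as in the removed code):

    nt = torch.nested.as_nested_tensor(test_tensor_list, layout=torch.jagged)
    # Previously raised "normalizing over ragged dim not supported for nested tensors";
    # with this change it is expected to return a jagged nested tensor instead.
    out = torch.nn.functional.layer_norm(nt, (nt.shape[-2], nt.shape[-1]))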