
Commit 46994e7

jananisriram authored and pytorchmergebot committed
[NestedTensor] Integrate the layer normalization operator along the jagged dimension into NestedTensor (pytorch#132172)
Modify the existing `layer normalization` operator in PyTorch, invoked by `torch.layer_norm`, to allow for reductions along the jagged dimension of a nested tensor. The function originally had a basic implementation for reducing along one non-ragged dimension. This diff, which uses the `aten` padding operator, enables PyTorch users to invoke `torch.nn.functional.layer_norm` on a nested tensor when reducing along the ragged dimension, e.g. `*` in a `(B, *, M)` or `(B, *, M, N)` nested tensor.

Write unit tests based on the `softmax` jagged operator to verify the accuracy of the ragged-reduction implementation for `torch.nn.functional.layer_norm`. Add unit tests to verify error handling for unsupported features.

Note that this implementation is limited to nested tensors with `ragged_idx == 1`, i.e. nested tensors whose ragged dimension is not transposed. The layer normalization operator also requires operating on a 2-dimensional layer; for nested tensors with 4 or more dimensions, I flatten the extra dimensions, then unflatten them after performing layer normalization.

Pull Request resolved: pytorch#132172
Approved by: https://github.com/davidberard98
ghstack dependencies: pytorch#132170
1 parent 89053e3 commit 46994e7
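For reference, a minimal usage sketch (not part of the commit, assuming a PyTorch build that includes this change). It mirrors the new `test_layer_norm_reduce_ragged_idx_1` test below: layer norm over everything from the ragged dimension onward on a `(B, *, M)` jagged nested tensor, checked against per-component `layer_norm` calls. The component shapes are made up for illustration; `nt._ragged_idx` is the private attribute the tests themselves use.

```python
import torch

# Components of a (B, *, M) jagged nested tensor; the ragged dimension is dim 1.
tensor_list = [torch.randn(n, 5) for n in (2, 3, 4)]
nt = torch.nested.nested_tensor(tensor_list, layout=torch.jagged)

# Normalize over every dimension from the ragged dim onward, i.e. (*, M).
normalized_shape = nt.shape[nt._ragged_idx :]
out = torch.nn.functional.layer_norm(nt, normalized_shape=normalized_shape)

# Reference result: layer-normalize each (*, M) component independently.
expected = torch.cat(
    [torch.nn.functional.layer_norm(t, normalized_shape=t.shape) for t in nt.unbind()]
)
assert out.is_nested
assert torch.allclose(out.values(), expected)
```

The comparison is the same one the test uses: `nt.unbind()` recovers the per-sample `(*, M)` tensors, so each sample is normalized over its own ragged length.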

File tree

2 files changed: +292 −40 lines

test/test_nestedtensor.py: +210 −37
@@ -3496,6 +3496,7 @@ def _get_example_tensor_lists(
         include_list_of_lists=True,
         include_requires_grad=True,
         include_inner_dim_size_1=False,
+        include_2d_tensor=False,
     ):
         def _make_tensor(
             *shape, include_requires_grad=include_requires_grad, requires_grad=True
@@ -3562,6 +3563,16 @@ def _make_tensor(
                 ]  # (B, *, 5, 1)
             )
 
+        if include_2d_tensor:
+            example_lists.append(
+                [
+                    _make_tensor(2),
+                    _make_tensor(3, requires_grad=False),
+                    _make_tensor(4, requires_grad=False),
+                    _make_tensor(6),
+                ]  # (B, *)
+            )
+
         return example_lists
 
     def test_tensor_attributes(self, device):
@@ -4137,7 +4148,7 @@ def test_jagged_op_different_output_shape_dim(
     @dtypes(torch.float32)
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
-    def test_jagged_softmax_dim(
+    def test_softmax_dim(
         self,
         device,
         dtype,
@@ -4210,7 +4221,7 @@ def test_jagged_softmax_dim(
     @parametrize("keepdim", [False, True])
     @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
-    def test_jagged_op_dim_reduce_ragged_idx_1_different_output_shape(
+    def test_op_dim_reduce_ragged_idx_1_different_output_shape(
         self, device, dtype, keepdim, requires_grad, components_require_grad, func
     ):
         """
@@ -4252,7 +4263,7 @@ def test_jagged_op_dim_reduce_ragged_idx_1_different_output_shape(
     @dtypes(torch.float32)
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
-    def test_jagged_softmax_dim_reduce_ragged_idx_1_same_output_shape(
+    def test_softmax_dim_reduce_ragged_idx_1(
         self, device, dtype, requires_grad, components_require_grad
     ):
         """
@@ -4319,6 +4330,133 @@ def test_softmax_reduce_batch_dim(
         ):
             out = torch.nn.functional.softmax(nt, dim=reduce_dim)
 
+    @dtypes(torch.float32)
+    @parametrize("requires_grad", [False, True])
+    @parametrize("components_require_grad", [False, True])
+    def test_layer_norm_reduce_ragged_idx_1(
+        self, device, dtype, requires_grad, components_require_grad
+    ):
+        """
+        Layer normalization on NestedTensor passes when trying to normalize across ragged dimension, where ragged_idx == 1.
+        """
+
+        # requires_grad = False does not currently work with dynamo tests and throws this error:
+        # AssertionError: SymInts must use SymNodeVariable.
+        # If the underlying value is static, we will create a ConstantVariable and specialize.
+        if torch._dynamo.is_compiling() and not requires_grad:
+            return
+
+        tensor_lists = self._get_example_tensor_lists(
+            include_list_of_lists=False,
+            include_requires_grad=components_require_grad,
+            include_inner_dim_size_1=True,  # (B, *, 1)
+        )
+
+        for tensor_list in tensor_lists:
+            nt = torch.nested.nested_tensor(
+                tensor_list,
+                device=device,
+                dtype=dtype,
+                layout=torch.jagged,
+                requires_grad=requires_grad,
+            )
+
+            if (
+                nt.dim() >= 3
+            ):  # layer norm only works for tensors with 3 or more dimensions
+                normalized_shape = nt.shape[nt._ragged_idx :]
+
+                out_actual = torch.nn.functional.layer_norm(
+                    nt, normalized_shape=normalized_shape
+                )
+                out_expected = torch.cat(
+                    [
+                        torch.nn.functional.layer_norm(t, normalized_shape=t.shape)
+                        for t in nt.unbind()
+                    ]
+                )  # e.g. in 3D tensor (B, *, M), performs layer normalization on B 2D tensors (*, M)
+
+                self.assertTrue(
+                    out_actual.is_nested,
+                    "layer_norm(): the result of reducing a nested tensor along the ragged dimension is a nested tensor",
+                )  # output is a nested tensor
+                self.assertEqual(out_actual._values.shape, out_expected.shape)
+                self.assertTrue(torch.allclose(out_actual.values(), out_expected))
+
+    @dtypes(torch.float32)
+    @parametrize("requires_grad", [False, True])
+    @parametrize("components_require_grad", [False, True])
+    def test_layer_norm_2d_input(
+        self,
+        device,
+        dtype,
+        requires_grad,
+        components_require_grad,
+    ):
+        """
+        Layer normalization on NestedTensor fails when trying to operate on a 2-dimensional tensor
+        """
+        tensor_lists = self._get_example_tensor_lists(
+            include_list_of_lists=False,
+            include_requires_grad=components_require_grad,
+            include_inner_dim_size_1=True,  # (B, *, 1)
+            include_2d_tensor=True,  # (B, *)
+        )
+
+        for tensor_list in tensor_lists:
+            nt = torch.nested.nested_tensor(
+                tensor_list,
+                device=device,
+                dtype=dtype,
+                layout=torch.jagged,
+                requires_grad=requires_grad,
+            )
+
+            if nt.dim() <= 2:
+                with self.assertRaisesRegex(
+                    RuntimeError,
+                    "not supported for NestedTensor objects with 2 or fewer dimensions",
+                ):
+                    out = torch.nn.functional.layer_norm(
+                        nt, normalized_shape=(nt.shape[nt._ragged_idx],)
+                    )
+
+    @dtypes(torch.float32)
+    @parametrize("requires_grad", [False, True])
+    @parametrize("components_require_grad", [False, True])
+    def test_layer_norm_operate_on_batch_dim(
+        self,
+        device,
+        dtype,
+        requires_grad,
+        components_require_grad,
+    ):
+        """
+        Layer normalization on NestedTensor fails when trying to operate on the batch dimension
+        """
+        tensor_lists = self._get_example_tensor_lists(
+            include_list_of_lists=False,
+            include_requires_grad=components_require_grad,
+            include_inner_dim_size_1=True,  # (B, *, 1)
+            include_2d_tensor=True,  # (B, *)
+        )
+
+        for tensor_list in tensor_lists:
+            nt = torch.nested.nested_tensor(
+                tensor_list,
+                device=device,
+                dtype=dtype,
+                layout=torch.jagged,
+                requires_grad=requires_grad,
+            )
+
+            if nt.dim() > 2:  # cannot perform layer normalization on 2D tensors
+                with self.assertRaisesRegex(
+                    RuntimeError,
+                    "not supported when normalizing over the batch dimension for NestedTensor",
+                ):
+                    out = torch.nn.functional.layer_norm(nt, normalized_shape=nt.shape)
+
     @dtypes(torch.float32)
     @parametrize(
         "func",
@@ -4331,7 +4469,7 @@ def test_softmax_reduce_batch_dim(
     @parametrize("keepdim", [False, True])
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
-    def test_jagged_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape(
+    def test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape(
         self,
         device,
         dtype,
@@ -4391,7 +4529,7 @@ def test_jagged_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape(
     )  # [transpose consecutive dimensions, transpose nonconsecutive dimensions]
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
-    def test_jagged_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape(
+    def test_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape(
         self,
         device,
         dtype,
@@ -4439,7 +4577,7 @@ def test_jagged_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape(
     @parametrize("keepdim", [False, True])
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
-    def test_jagged_op_dim_transpose_non_ragged_dim_different_output_shape(
+    def test_op_dim_transpose_non_ragged_dim_different_output_shape(
         self, device, dtype, keepdim, requires_grad, components_require_grad, func
     ):
         """
@@ -4508,7 +4646,7 @@ def test_jagged_op_dim_transpose_non_ragged_dim_different_output_shape(
     @dtypes(torch.float32)
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
-    def test_jagged_softmax_dim_transpose_non_ragged_dim(
+    def test_softmax_dim_transpose_non_ragged_dim(
         self,
         device,
         dtype,
@@ -4560,7 +4698,7 @@ def test_jagged_softmax_dim_transpose_non_ragged_dim(
     @parametrize("keepdim", [False, True])
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
-    def test_jagged_sum_dim_reduce_ragged_and_non_batch(
+    def test_sum_dim_reduce_ragged_and_non_batch(
         self,
         device,
         dtype,
@@ -4599,7 +4737,7 @@ def test_jagged_sum_dim_reduce_ragged_and_non_batch(
     @parametrize("keepdim", [False, True])
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
-    def test_jagged_sum_dim_reduce_batch_and_non_batch(
+    def test_sum_dim_reduce_batch_and_non_batch(
         self,
         device,
         dtype,
@@ -4643,7 +4781,7 @@ def test_jagged_sum_dim_reduce_batch_and_non_batch(
     @parametrize("keepdim", [False, True])
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
-    def test_jagged_op_dim_reduce_batch_only_different_output_shape(
+    def test_op_dim_reduce_batch_only_different_output_shape(
         self, device, dtype, keepdim, requires_grad, components_require_grad, func
     ):
         """
@@ -4681,7 +4819,7 @@ def test_jagged_op_dim_reduce_batch_only_different_output_shape(
     @parametrize("keepdim", [False, True])
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
-    def test_jagged_op_dim_with_lengths_different_output_shape(
+    def test_op_dim_with_lengths_different_output_shape(
         self,
         device,
         dtype,
@@ -4736,7 +4874,7 @@ def test_jagged_op_dim_with_lengths_different_output_shape(
     @dtypes(torch.float32)
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
-    def test_jagged_softmax_dim_with_lengths(
+    def test_softmax_dim_with_lengths(
         self,
         device,
         dtype,
@@ -4782,11 +4920,69 @@ def test_jagged_softmax_dim_with_lengths(
             else:
                 out = torch.nn.functional.softmax(nt_with_holes, dim=reduce_dim)
 
+    @skipIfTorchDynamo(
+        "ragged_size = nt_with_holes.shape[nt_with_holes._ragged_idx] does not currently work "
+        + "with dynamo tests and throws this error: `AssertionError: SymInts must use SymNodeVariable. "
+        + "If the underlying value is static, we will create a ConstantVariable and specialize.`"
+    )
+    @dtypes(torch.float32)
+    @parametrize("requires_grad", [False, True])
+    @parametrize("components_require_grad", [False, True])
+    def test_layer_norm_with_lengths(
+        self,
+        device,
+        dtype,
+        requires_grad,
+        components_require_grad,
+    ):
+        """
+        Layer normalization on NestedTensor fails when trying to operate on a nested tensor with lengths,
+        i.e. a nested tensor with holes, if operating on the ragged dimension.
+        """
+
+        # create components for nested tensor
+        lengths = torch.randint(5, 10, (20,), device=device)
+        offsets = torch.zeros((21,), device=device, dtype=torch.int)
+        torch.cumsum(lengths, dim=0, out=offsets[1:])
+        values = torch.randn(
+            (offsets[-1].item(), 10, 30),
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+        )
+
+        nt_with_holes = torch.nested.nested_tensor_from_jagged(
+            values,
+            offsets,
+            lengths=offsets.diff() - 2,  # arbitrary subtraction to create holes
+        )
+
+        ragged_size = nt_with_holes.shape[nt_with_holes._ragged_idx]
+
+        normalized_shapes = (
+            (10, 30),  # normalization on non-ragged dimension passes
+            (ragged_size, 10, 30),  # normalization on ragged dimension fails
+        )
+
+        for normalized_shape in normalized_shapes:
+            if ragged_size in normalized_shape:
+                with self.assertRaisesRegex(
+                    RuntimeError,
+                    "not supported where lengths is not None if operating on the ragged dimension for NestedTensor",
+                ):
+                    out = torch.nn.functional.layer_norm(
+                        nt_with_holes, normalized_shape=normalized_shape
+                    )
+            else:
+                out = torch.nn.functional.layer_norm(
+                    nt_with_holes, normalized_shape=normalized_shape
+                )
+
     @dtypes(torch.float32)
     @parametrize("keepdim", [True])
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
-    def test_jagged_mean_dim_reduce_multiple_dims(
+    def test_mean_dim_reduce_multiple_dims(
         self,
         device,
         dtype,
@@ -4826,7 +5022,7 @@ def test_jagged_mean_dim_reduce_multiple_dims(
     @parametrize("keepdim", [False, True])
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
-    def test_jagged_mean_dim_keepdim_False(
+    def test_mean_dim_keepdim_False(
         self,
         device,
         dtype,
@@ -5548,29 +5744,6 @@ def test_unbind_lengths_ragged_idx_0(self, device):
             lambda: nt.unbind(),
         )
 
-    @xfailIfTorchDynamo
-    def test_layer_norm_2(self, device):
-        test_tensor_list = self._get_list_for_jagged_tensor(
-            ((2, 3, 4), 3), device=device, requires_grad=True
-        )
-        bias = torch.randn(3, requires_grad=False, dtype=torch.float64, device=device)
-
-        def grad_test_func(a, b, c, bias):
-            nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
-            out = torch.nn.functional.layer_norm(nt, (nt.shape[-1],), bias=bias)
-            return out.values()
-
-        gradcheck(
-            grad_test_func, inputs=(*test_tensor_list, bias), check_batched_grad=False
-        )
-
-        with self.assertRaisesRegex(
-            RuntimeError,
-            r"layer_norm\(\): normalizing over ragged dim not supported for nested tensors",
-        ):
-            nt = torch.nested.as_nested_tensor(test_tensor_list, layout=torch.jagged)
-            _ = torch.nn.functional.layer_norm(nt, (nt.shape[-2], nt.shape[-1]))
-
     def test_narrow(self, device):
         starts = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64)
         lengths = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64)
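As a further hedged sketch (not part of the commit), these are the unsupported configurations the new error-handling tests above exercise; the quoted `RuntimeError` substrings are taken from the test regexes, and the tensor shapes are made up for illustration.

```python
import torch

F = torch.nn.functional

# Normalizing over the batch dimension of a (B, *, M) jagged nested tensor is rejected:
# "not supported when normalizing over the batch dimension for NestedTensor"
nt_3d = torch.nested.nested_tensor(
    [torch.randn(2, 5), torch.randn(3, 5)], layout=torch.jagged
)
try:
    F.layer_norm(nt_3d, normalized_shape=nt_3d.shape)
except RuntimeError as e:
    print("batch-dim case:", e)

# A 2-dimensional (B, *) jagged nested tensor is rejected outright:
# "not supported for NestedTensor objects with 2 or fewer dimensions"
nt_2d = torch.nested.nested_tensor(
    [torch.randn(2), torch.randn(3)], layout=torch.jagged
)
try:
    F.layer_norm(nt_2d, normalized_shape=(nt_2d.shape[nt_2d._ragged_idx],))
except RuntimeError as e:
    print("2D case:", e)
```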
