
Commit afaee00

mikaylagawarecki authored and pytorchmergebot committed
Add python nested_tensor and as_nested_tensor constructors in torch.nested (pytorch#85593)
Remove `torch.nested_tensor`, which has erroneous behavior with respect to gradients (the result could be either a leaf or not a leaf). Introduce `torch.nested.nested_tensor` and `torch.nested.as_nested_tensor` in the vein of `torch.tensor` and `torch.as_tensor`. This is done in nested `__init__.py` for now but can move to pybind in the future (when we want to load from numpy/nested lists). Discussed offline with @cpuhrsch: the pybind constructor (pytorch#85536) was more gnarly than expected, so we can move to that when we actually need loading from numpy etc.

Differential Revision: [D39806622](https://our.internmc.facebook.com/intern/diff/D39806622)

Pull Request resolved: pytorch#85593
Approved by: https://github.com/drisspg, https://github.com/cpuhrsch
1 parent a876432 · commit afaee00
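A minimal sketch of the constructor semantics described in the commit message above, assuming only what it states (`torch.nested.nested_tensor` behaves like `torch.tensor` and can be a leaf; `torch.nested.as_nested_tensor` behaves like `torch.as_tensor` and preserves the autograd history of its inputs):

    import torch

    a = torch.randn(2, 3, requires_grad=True)
    b = torch.randn(4, 3, requires_grad=True)

    # Copying constructor: with requires_grad=True the result is a leaf,
    # analogous to torch.tensor.
    nt = torch.nested.nested_tensor([a, b], requires_grad=True)
    assert nt.is_leaf

    # Graph-preserving constructor: gradients flow back to a and b,
    # analogous to torch.as_tensor, so the result is not a leaf here.
    nt2 = torch.nested.as_nested_tensor([a, b])
    assert not nt2.is_leaf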

File tree

17 files changed (+381 -225)


aten/src/ATen/native/native_functions.yaml (+3 -3)

@@ -12716,11 +12716,11 @@
   variants: function
   python_module: nn

-- func: nested_tensor(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+- func: _nested_tensor_from_tensor_list(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
   variants: function
   dispatch:
-    CompositeExplicitAutograd: nested_tensor
-  autogen: nested_tensor.out
+    CompositeExplicitAutograd: _nested_tensor_from_tensor_list
+  autogen: _nested_tensor_from_tensor_list.out

 - func: _fw_primal_copy(Tensor self, int level) -> Tensor
   variants: function
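The rename only touches the private ATen op; the public entry point is now `torch.nested.nested_tensor`. A quick way to see the relationship, assuming the standard codegen exposure of `native_functions.yaml` entries under `torch.ops.aten`:

    import torch

    ts = [torch.randn(2, 3), torch.randn(4, 3)]

    # Public constructor introduced by this PR.
    nt = torch.nested.nested_tensor(ts)

    # The private op that replaces aten::nested_tensor (assumption:
    # reachable via torch.ops.aten, as usual for autogenerated bindings).
    nt2 = torch.ops.aten._nested_tensor_from_tensor_list(ts)
    assert nt.is_nested and nt2.is_nested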

aten/src/ATen/native/nested/NestedTensorMath.cpp (+1 -1)

@@ -161,7 +161,7 @@ bool NestedTensor_nested_tensor_from_mask_left_aligned(const Tensor& t, const Te
   return sizes.equal(nums);
 }

-Tensor nested_tensor(
+Tensor _nested_tensor_from_tensor_list(
     TensorList list,
     c10::optional<ScalarType> dtype,
     c10::optional<Layout> layout,
docs/source/nested.rst (+9 -7)

@@ -23,7 +23,7 @@ Construction is straightforward and involves passing a list of Tensors to the co
 tensor([0, 1, 2])
 >>> b
 tensor([3, 4, 5, 6, 7])
->>> nt = torch.nested_tensor([a, b])
+>>> nt = torch.nested.nested_tensor([a, b])
 >>> nt
 nested_tensor([
   tensor([0, 1, 2]),
@@ -32,7 +32,7 @@

 Data type and device can be chosen via the usual keyword arguments.

->>> nt = torch.nested_tensor([a, b], dtype=torch.float32, device="cuda")
+>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32, device="cuda")
 >>> nt
 nested_tensor([
   tensor([0., 1., 2.], device='cuda:0'),
@@ -43,15 +43,15 @@ In order to form a valid NestedTensor the passed Tensors also all need to match

 >>> a = torch.randn(3, 50, 70) # image 1
 >>> b = torch.randn(3, 128, 64) # image 2
->>> nt = torch.nested_tensor([a, b], dtype=torch.float32)
+>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32)
 >>> nt.dim()
 4

 If one of the dimensions don't match, the constructor throws an error.

 >>> a = torch.randn(50, 128) # text 1
 >>> b = torch.randn(3, 128, 64) # image 2
->>> nt = torch.nested_tensor([a, b], dtype=torch.float32)
+>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32)
 Traceback (most recent call last):
   File "<stdin>", line 1, in <module>
 RuntimeError: All Tensors given to nested_tensor must have the same dimension. Found dimension 3 for Tensor at index 1 and dimension 2 for Tensor at index 0.
@@ -73,7 +73,7 @@ Even though a NestedTensor does not support .size() (or .shape), it supports .si

 >>> a = torch.randn(50, 128) # text 1
 >>> b = torch.randn(32, 128) # text 2
->>> nt = torch.nested_tensor([a, b], dtype=torch.float32)
+>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32)
 >>> nt.size(0)
 2
 >>> nt.size(1)
@@ -86,7 +86,7 @@ RuntimeError: Given dimension 1 is irregular and does not have a size.
 If all dimensions are regular, the NestedTensor is intended to be semantically indistinguishable from a regular torch.Tensor.

 >>> a = torch.randn(20, 128) # text 1
->>> nt = torch.nested_tensor([a, a], dtype=torch.float32)
+>>> nt = torch.nested.nested_tensor([a, a], dtype=torch.float32)
 >>> nt.size(0)
 2
 >>> nt.size(1)
@@ -112,7 +112,7 @@ unbind allows you to retrieve a view of the constituents.
 >>> import torch
 >>> a = torch.randn(2, 3)
 >>> b = torch.randn(3, 4)
->>> nt = torch.nested_tensor([a, b], dtype=torch.float32)
+>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32)
 >>> nt
 nested_tensor([
   tensor([[ 1.2286, -1.2343, -1.4842],
@@ -149,4 +149,6 @@ The following functions are related to nested tensors:

 .. currentmodule:: torch.nested

+.. autofunction:: nested_tensor
+.. autofunction:: as_nested_tensor
 .. autofunction:: to_padded_tensor
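An end-to-end sketch of the documented flow, using only functions named in this file (construct a nested tensor, then materialize a padded dense tensor):

    import torch

    a = torch.randn(50, 128)  # text 1
    b = torch.randn(32, 128)  # text 2
    nt = torch.nested.nested_tensor([a, b], dtype=torch.float32)

    # to_padded_tensor pads the shorter constituent with the given value
    # and returns a regular torch.Tensor.
    padded = torch.nested.to_padded_tensor(nt, 0.0)
    assert padded.shape == (2, 50, 128)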

test/cpp/api/nested.cpp (+1 -1)

@@ -10,6 +10,6 @@
 TEST(NestedTest, Nested) {
   auto a = torch::randn({2, 3});
   auto b = torch::randn({4, 5});
-  auto nt = torch::nested_tensor({a, b});
+  auto nt = torch::nested::nested_tensor({a, b});
   torch::nested::to_padded_tensor(nt, 0);
 }

test/forward_backward_compatibility/check_forward_backward_compatibility.py (+2 -0)

@@ -284,6 +284,8 @@
     ("c10d::allgather_", datetime.date(2022, 10, 1)),
     ("aten::to_padded_tensor", datetime.date(2022, 10, 1)),
     ("aten::nested_to_padded_tensor", datetime.date(2022, 10, 1)),
+    ("aten::nested_tensor", datetime.date(2022, 10, 15)),
+
 ]

 ALLOW_LIST_COMPILED = [

test/profiler/test_profiler.py (+1 -1)

@@ -1248,7 +1248,7 @@ def test_nested_tensor_with_shapes(self):
         a = torch.randn(4, 4)
         b = torch.randn(4, 4)
         c = torch.randn(4, 4)
-        inp = torch.nested_tensor([a, b])
+        inp = torch.nested.nested_tensor([a, b])
         with torch.profiler.profile(record_shapes=True) as prof:
             torch.nn.functional.linear(inp, c, None)
         for e in prof.events():
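Condensed, the behavior this test checks: shapes of nested-tensor inputs are recorded by the profiler when `record_shapes=True` (tensor values are illustrative):

    import torch

    inp = torch.nested.nested_tensor([torch.randn(4, 4), torch.randn(4, 4)])
    weight = torch.randn(4, 4)
    with torch.profiler.profile(record_shapes=True) as prof:
        torch.nn.functional.linear(inp, weight, None)

    # Each profiler event carries the input shapes it observed.
    for e in prof.events():
        if e.name == "aten::linear":
            print(e.input_shapes)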

test/test_autograd.py (+8 -8)

@@ -3569,12 +3569,12 @@ def test_calculate_shape_util(self):
         assert out_shape == torch.Size([10, 5])
         assert grad_shape == torch.Size([5, 10])

-        out = torch.nested_tensor([
+        out = torch.nested.as_nested_tensor([
             torch.randn(10, 5, requires_grad=True),
             torch.randn(10, 5, requires_grad=True),
             torch.randn(10, 5, requires_grad=True)]
         )
-        grad = torch.nested_tensor([torch.randn(5, 10, requires_grad=True), torch.randn(5, 10, requires_grad=True)])
+        grad = torch.nested.as_nested_tensor([torch.randn(5, 10, requires_grad=True), torch.randn(5, 10, requires_grad=True)])
         out_shape, grad_shape = _calculate_shape(out, grad, False)

         assert torch.equal(out_shape, torch.tensor([[10, 5], [10, 5], [10, 5]]))
@@ -9178,12 +9178,12 @@ def test_autograd_multiple_dispatch_registrations(self, device):
         # test registered AutogradNestedTensor formula
         a = torch.arange(6, dtype=torch.float, device=device).reshape(2, 3).requires_grad_(True)
         b = torch.arange(8, dtype=torch.float, device=device).reshape(2, 4).requires_grad_(True)
-        nt = torch.nested_tensor([a, b], dtype=torch.float, device=device)
+        nt = torch.nested.as_nested_tensor([a, b], dtype=torch.float, device=device)

         nt_out = torch._test_autograd_multiple_dispatch(nt)
         c = torch.randn(2, 3, device=device)
         d = torch.randn(2, 4, device=device)
-        nt_grad = torch.nested_tensor([c, d], dtype=torch.float, device=device)
+        nt_grad = torch.nested.nested_tensor([c, d], dtype=torch.float, device=device)
         nt_out.backward(nt_grad)

         # bogus gradient for AutogradNestedTensor is grad * grad
@@ -9204,12 +9204,12 @@ def test_autograd_composite_implicit_and_dispatch_registration(self, device):
         # test registered AutogradNestedTensor formula
         a = torch.arange(6, dtype=torch.float, device=device).reshape(2, 3).requires_grad_(True)
         b = torch.arange(8, dtype=torch.float, device=device).reshape(2, 4).requires_grad_(True)
-        nt = torch.nested_tensor([a, b], dtype=torch.float, device=device)
+        nt = torch.nested.as_nested_tensor([a, b], dtype=torch.float, device=device)

         nt_out = torch._test_autograd_multiple_dispatch(nt, True)
         c = torch.randn(2, 3, device=device)
         d = torch.randn(2, 4, device=device)
-        nt_grad = torch.nested_tensor([c, d], dtype=torch.float, device=device)
+        nt_grad = torch.nested.nested_tensor([c, d], dtype=torch.float, device=device)
         nt_out.backward(nt_grad)

         # bogus gradient for AutogradNestedTensor is grad * grad + grad
@@ -9274,9 +9274,9 @@ def foo(x):
         foo(inp).backward()

         # sum's input is saved for Nested Tensors
-        nt = torch.nested_tensor([torch.rand(2), torch.rand(2)], device=device).requires_grad_()
+        nt = torch.nested.nested_tensor([torch.rand(2), torch.rand(2)], device=device, requires_grad=True)
         with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
-            foo(nt).backward(torch.nested_tensor([torch.rand(1), torch.rand(1)], device=device))
+            foo(nt).backward(torch.nested.nested_tensor([torch.rand(1), torch.rand(1)], device=device))

         # Import test cases from below autograd/ here. These are found
         # implicitly by the loader, so Flake8 thinks they are unused, hence
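These hunks pick the constructor to match intent: `as_nested_tensor` where gradients must reach the constituent tensors, `nested_tensor` where a detached value such as an incoming gradient is wanted. A minimal sketch of the first case, assuming backward accepts a nested gradient as in the tests above:

    import torch

    a = torch.randn(2, 3, requires_grad=True)
    b = torch.randn(2, 4, requires_grad=True)

    # as_nested_tensor keeps a and b in the autograd graph, so gradients
    # propagate back to the constituents.
    nt = torch.nested.as_nested_tensor([a, b])
    grad = torch.nested.nested_tensor([torch.ones(2, 3), torch.ones(2, 4)])
    nt.backward(grad)
    assert torch.equal(a.grad, torch.ones(2, 3))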

test/test_native_mha.py (+5 -5)

@@ -39,7 +39,7 @@ def _test_transform_bias_rescale_qkv_impl(
         xs = list(torch.unbind(x))
         if use_padding:
             xs[0] = xs[0][:-1]
-        x = torch.nested_tensor(xs, device=device, dtype=dtype)
+        x = torch.nested.nested_tensor(xs, device=device, dtype=dtype)
         qkv = torch.nn.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype)

         # We have to use inference_mode here because q/k/v are
@@ -199,15 +199,15 @@ def forward(self, q, k, v, key_padding_mask):
             qs = [x[:-1] for x in qs]
         else:
             qs[0] = qs[0][:-1]
-        q = torch.nested_tensor(qs, device=device, dtype=dtype)
+        q = torch.nested.nested_tensor(qs, device=device, dtype=dtype)
         if mode == "self":
             k = v = q
         elif mode == "encdec":
-            k = torch.nested_tensor(torch.unbind(k), device=device, dtype=dtype)
+            k = torch.nested.nested_tensor(torch.unbind(k), device=device, dtype=dtype)
             v = k
         else:
-            k = torch.nested_tensor(torch.unbind(k), device=device, dtype=dtype)
-            v = torch.nested_tensor(torch.unbind(v), device=device, dtype=dtype)
+            k = torch.nested.nested_tensor(torch.unbind(k), device=device, dtype=dtype)
+            v = torch.nested.nested_tensor(torch.unbind(v), device=device, dtype=dtype)

         ynpt, weight_npt = npt(
             q, k, v, key_padding_mask=mask if use_padding and not use_nt else None
