Add Python nested_tensor and as_nested_tensor constructors in torch.nested (pytorch#85593)
Remove `torch.nested_tensor`, which has erroneous behavior with respect to gradients (the result could be either a leaf or a non-leaf tensor). Introduce `torch.nested.nested_tensor` and `torch.nested.as_nested_tensor` in the vein of `torch.tensor` and `torch.as_tensor`. Implemented in the nested `__init__.py` for now, but this can move to pybind in the future (when we want to load from numpy/nested lists).
Discussed offline with @cpuhrsch: the pybind constructor (pytorch#85536) turned out to be more gnarly than expected, so we can move to it when loading from numpy etc. actually becomes necessary.
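By analogy with `torch.tensor` (which always copies its input) and `torch.as_tensor` (which avoids a copy when it can), the intended contract of the two new constructors can be sketched in plain Python. This is a conceptual illustration only, not the actual PyTorch implementation: plain lists stand in for tensors, and the copy/alias distinction stands in for the leaf/autograd-history distinction the commit message describes.

```python
# Conceptual sketch of the copy-vs-alias contract behind
# torch.nested.nested_tensor / torch.nested.as_nested_tensor.
# NOT the PyTorch implementation; plain lists stand in for tensors.

def nested_tensor(tensor_list):
    """Like torch.tensor: always copies the inputs, so the result is
    a fresh leaf with no lingering connection to the constituents."""
    return [list(t) for t in tensor_list]  # deep copy of each constituent

def as_nested_tensor(tensor_list):
    """Like torch.as_tensor: may alias the inputs instead of copying,
    so the result stays connected to the original constituents."""
    return list(tensor_list)  # shallow: inner lists are shared

a = [0, 1, 2]
b = [3, 4, 5, 6, 7]

copied = nested_tensor([a, b])
aliased = as_nested_tensor([a, b])

# Mutating an input shows the difference:
a[0] = 99
print(copied[0][0])   # 0  -- the copy is unaffected
print(aliased[0][0])  # 99 -- the alias sees the mutation
```

The point of splitting the API this way is that each constructor has one unambiguous behavior, instead of the removed `torch.nested_tensor`, whose result could be either a leaf or a non-leaf depending on its inputs.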
Differential Revision: [D39806622](https://our.internmc.facebook.com/intern/diff/D39806622)
Pull Request resolved: pytorch#85593
Approved by: https://github.com/drisspg, https://github.com/cpuhrsch
docs/source/nested.rst (+9 -7)
@@ -23,7 +23,7 @@ Construction is straightforward and involves passing a list of Tensors to the co
 tensor([0, 1, 2])
 >>> b
 tensor([3, 4, 5, 6, 7])
->>> nt = torch.nested_tensor([a, b])
+>>> nt = torch.nested.nested_tensor([a, b])
 >>> nt
 nested_tensor([
   tensor([0, 1, 2]),
@@ -32,7 +32,7 @@ nested_tensor([

 Data type and device can be chosen via the usual keyword arguments.

->>> nt = torch.nested_tensor([a, b], dtype=torch.float32, device="cuda")
+>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32, device="cuda")
 >>> nt
 nested_tensor([
   tensor([0., 1., 2.], device='cuda:0'),
@@ -43,15 +43,15 @@ In order to form a valid NestedTensor the passed Tensors also all need to match

 >>> a = torch.randn(3, 50, 70) # image 1
 >>> b = torch.randn(3, 128, 64) # image 2
->>> nt = torch.nested_tensor([a, b], dtype=torch.float32)
+>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32)
 >>> nt.dim()
 4

 If one of the dimensions don't match, the constructor throws an error.

 >>> a = torch.randn(50, 128) # text 1
 >>> b = torch.randn(3, 128, 64) # image 2
->>> nt = torch.nested_tensor([a, b], dtype=torch.float32)
+>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32)
 Traceback (most recent call last):
   File "<stdin>", line 1, in <module>
 RuntimeError: All Tensors given to nested_tensor must have the same dimension. Found dimension 3 for Tensor at index 1 and dimension 2 for Tensor at index 0.
@@ -73,7 +73,7 @@ Even though a NestedTensor does not support .size() (or .shape), it supports .si

 >>> a = torch.randn(50, 128) # text 1
 >>> b = torch.randn(32, 128) # text 2
->>> nt = torch.nested_tensor([a, b], dtype=torch.float32)
+>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32)
 >>> nt.size(0)
 2
 >>> nt.size(1)
@@ -86,7 +86,7 @@ RuntimeError: Given dimension 1 is irregular and does not have a size.
 If all dimensions are regular, the NestedTensor is intended to be semantically indistinguishable from a regular torch.Tensor.

 >>> a = torch.randn(20, 128) # text 1
->>> nt = torch.nested_tensor([a, a], dtype=torch.float32)
+>>> nt = torch.nested.nested_tensor([a, a], dtype=torch.float32)
 >>> nt.size(0)
 2
 >>> nt.size(1)
@@ -112,7 +112,7 @@ unbind allows you to retrieve a view of the constituents.
 >>> import torch
 >>> a = torch.randn(2, 3)
 >>> b = torch.randn(3, 4)
->>> nt = torch.nested_tensor([a, b], dtype=torch.float32)
+>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32)
 >>> nt
 nested_tensor([
   tensor([[ 1.2286, -1.2343, -1.4842],
@@ -149,4 +149,6 @@ The following functions are related to nested tensors: