
Commit 19918a1

jbschlosser authored and pytorchmergebot committed
Fix autograd.Function + NJT when an output grad is None (pytorch#136875)
For `autograd.Function`, the engine will try to allocate correctly-shaped zeros for `None` grads (i.e. in the case where the output isn't used downstream). It determines the shape of these zeros from the `VariableInfo` entry, which is derived from the forward output shape. For the NJT forward output case, the size info stored will contain a nested int, and calling `zeros()` with this size throws:

```
RuntimeError: .../build/aten/src/ATen/RegisterCPU.cpp:5260: SymIntArrayRef expected to contain only concrete integers
```

This PR fixes this by storing the full tensor in the `VariableInfo` for the nested case and calling `zeros_like()` to allocate correctly-shaped zeros. This is pretty inefficient; ideally we would want to save just the NJT shape and be able to construct zeros from it, but this requires factory function support for nested ints (WIP). So this is a short-term fix until we have that.

Pull Request resolved: pytorch#136875
Approved by: https://github.com/soulitzer, https://github.com/huydhn
1 parent 197601e commit 19918a1
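
As a quick illustration of the constraint the message describes (not part of the PR's diff): an NJT's sizes contain a nested (symbolic) int, so zeros cannot yet be built from the size list alone, while `zeros_like()` can mirror the jagged structure of an existing NJT directly. A minimal sketch, assuming `torch.nested.nested_tensor` with `layout=torch.jagged` is available in your build:

```python
import torch

# Build a small NJT (jagged layout); its shape contains a nested int.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 10), torch.randn(3, 10)],
    layout=torch.jagged,
)

# zeros_like() can produce correctly-shaped jagged zeros from the tensor itself,
# which is the mechanism the fix relies on for None output grads.
grads = torch.zeros_like(nt)
print(grads.shape)  # e.g. torch.Size([2, j1, 10]), where j1 is a nested int
```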


4 files changed: +52 -4 lines changed


test/test_nestedtensor.py

+30
@@ -7055,6 +7055,36 @@ def test_noncontiguous_to(self, device, dtype, contiguity):
         if nt._lengths is not None:
             self.assertEqual(nt3._lengths.device, other_device)
 
+    @dtypes(torch.float32)
+    def test_autograd_function_with_None_grad(self, device, dtype):
+        class MyFunction(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, inp):
+                ctx.save_for_backward(inp)
+                out1 = inp + 1
+                out2 = inp * 2
+                return out1, out2
+
+            @staticmethod
+            def backward(ctx, grad_out1, grad_out2):
+                (inp,) = ctx.saved_tensors
+                return grad_out1 + grad_out2
+
+        f = MyFunction.apply
+        nt = random_nt_from_dims(
+            [5, None, 10],
+            device=device,
+            dtype=dtype,
+            layout=torch.jagged,
+            requires_grad=True,
+        )
+
+        # Only use one of the autograd.Function outputs downstream so that the grad
+        # for the other output is None. We're testing that the engine can allocate
+        # correctly-shaped (NJT) zeros for the grad of the other output in this case.
+        (out1, _) = f(nt)
+        out1.backward(torch.ones_like(out1))
+
     @dtypes(torch.float64, torch.float32, torch.half)
     def test_jagged_padded_dense_conversion_kernels(self, device, dtype):
         values = torch.randn(10, 5, device=device, dtype=dtype)
torch/csrc/autograd/python_function.cpp

+11-1
@@ -733,8 +733,18 @@ static void _wrap_outputs(
       PyTuple_SetItem(outputs, i, obj);
     } else {
      if (is_executable) {
+        // If one of the grad outputs is undefined, a correctly-shaped zeros
+        // should be used instead. To construct these for NJT, zeros_like() must
+        // be used until we have factory function support.
        // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
-        self->output_info.emplace_back(*wrapped_outputs[i]);
+        bool is_differentiable =
+            (non_differentiable.count(
+                 wrapped_outputs[i]->unsafeGetTensorImpl()) == 0 &&
+             isDifferentiableType(wrapped_outputs[i]->scalar_type()));
+        bool use_zeros_like = is_differentiable && num_outputs > 1 &&
+            wrapped_outputs[i]->is_nested();
+        // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
+        self->output_info.emplace_back(*wrapped_outputs[i], use_zeros_like);
      }
      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      PyTuple_SetItem(outputs, i, THPVariable_Wrap(*wrapped_outputs[i]));
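
For context on the `use_zeros_like` flag above: when only some of an `autograd.Function`'s outputs are used downstream, the engine materializes correctly-shaped zeros for the unused outputs' grads (this is what `VariableInfo::zeros` feeds). A minimal dense-tensor sketch of that behavior, illustrative only and not taken from this PR:

```python
import torch

class TwoOutputs(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x + 1, x * 2

    @staticmethod
    def backward(ctx, grad1, grad2):
        # grad2 arrives as correctly-shaped zeros because the second output was
        # never used downstream (grad materialization is on by default).
        assert grad2 is not None and torch.equal(grad2, torch.zeros_like(grad2))
        return grad1 + grad2

x = torch.randn(3, requires_grad=True)
out1, _ = TwoOutputs.apply(x)
out1.sum().backward()
```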

torch/csrc/autograd/variable_info.cpp

+8-2
@@ -2,27 +2,33 @@
 #include <ATen/Functions.h>
 #else
 #include <ATen/ops/zeros.h>
+#include <ATen/ops/zeros_like.h>
 #endif
 
 #include <torch/csrc/autograd/variable.h>
 #include <torch/csrc/autograd/variable_info.h>
 
 namespace torch::autograd {
 
-VariableInfo::VariableInfo(const Variable& var)
+VariableInfo::VariableInfo(const Variable& var, bool use_zeros_like)
     : layout(var.layout()),
       device(var.device()),
       scalar_type(var.scalar_type()),
       size(var.sym_sizes().vec()),
       requires_grad(var.requires_grad()),
-      is_empty(false) {}
+      is_empty(false),
+      the_var(
+          use_zeros_like ? std::optional<Variable>(var.detach())
+                         : std::nullopt) {}
 
 VariableInfo::VariableInfo() : requires_grad(false), is_empty(true) {}
 
 Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const {
   if (is_empty) {
     // Return undefined tensor.
     return at::Tensor();
+  } else if (the_var.has_value()) {
+    return at::zeros_like(*the_var);
   } else {
     return at::zeros_symint(
         size, at::TensorOptions(scalar_type).device(device).layout(layout));

torch/csrc/autograd/variable_info.h

+3-1
@@ -6,7 +6,7 @@ namespace torch::autograd {
 
 struct TORCH_API VariableInfo {
   explicit VariableInfo();
-  explicit VariableInfo(const Variable& var);
+  explicit VariableInfo(const Variable& var, bool use_zeros_like = false);
 
   Variable zeros(at::OptionalDeviceGuard& device_guard) const;
 
@@ -16,6 +16,8 @@ struct TORCH_API VariableInfo {
   std::vector<c10::SymInt> size;
   bool requires_grad;
   bool is_empty;
+  // needed for e.g. NJTs since they only support zeros_like()
+  std::optional<Variable> the_var;
 };
 
 } // namespace torch::autograd
