Commit 5861279

Revert "Add support for index_put_ in NT (pytorch#135722)"
This reverts commit b4836e5. Reverted pytorch#135722 on behalf of https://github.com/huydhn because it is failing on ROCm ([comment](pytorch#135722 (comment))).
1 parent: 1797a20 · commit: 5861279

3 files changed: +1, -217 lines

test/test_nestedtensor.py (-60 lines)

@@ -6194,34 +6194,6 @@ def test_copy_(self, device):
         ):
             a.copy_(b)
 
-    # This can't happen in the opinfo tests due to subprocess creation
-    def test_index_put_error(self, device):
-        import subprocess
-
-        with self.subTest():
-            r = subprocess.call(
-                [
-                    sys.executable,
-                    "-c",
-                    """\
-import torch
-offsets = torch.tensor([0, 2, 5, 7], device='cuda')
-lengths = torch.tensor([2, 2, 2], device='cuda')
-indices = [
-    torch.tensor([0, 1, 2], device='cuda'),
-    torch.tensor([0, 2, 1], device='cuda'),
-    torch.tensor([0, 0, 0], device='cuda'),
-]
-a = torch.nested.nested_tensor_from_jagged(
-    torch.zeros(7, 3, device='cuda'), offsets, lengths
-)
-a[indices] = 1.0
-torch.cuda.synchronize()
-""",
-                ]
-            )
-            self.assertTrue(r != 0)
-
     @skipIfTorchDynamo("Dynamo doesn't know how to trace prof.events()")
     def test_profiler_sequence_nr(self):
         with torch.profiler.profile() as prof:

@@ -7943,12 +7915,6 @@ def test_forward(self, device, dtype, op):
             out_ref = op.ref(op, sample)
             self.assertEqualIgnoringNestedInts(out, out_ref)
 
-            # TODO: Revisit once https://github.com/pytorch/pytorch/pull/138369 lands
-            # TODO: Add xfails for other inplace ops instead of hardcoding
-            if op.inplace_variant and "index_put" in op.full_name:
-                op.inplace_variant(sample.input, *sample.args, **sample.kwargs)
-                self.assertEqualIgnoringNestedInts(sample.input, out_ref)
-
     @withXFails(BACKWARD_FAILURES)
     @ops(
         [op for op in njt_op_db if op.supports_njt and op.supports_autograd],

@@ -8004,32 +7970,6 @@ def f(*args, **kwargs):
             else:
                 self.assertEqual(out_compile, out_ref)
 
-            # TODO: Revisit once https://github.com/pytorch/pytorch/pull/138369 lands
-            # TODO: Add xfails for other inplace ops instead of hardcoding
-            if op.inplace_variant and "index_put" in op.full_name:
-                op_fn = op.inplace_variant
-
-                def in_f(*args, **kwargs):
-                    return op_fn(*args, **kwargs)
-
-                compiled_in_f = torch.compile(
-                    in_f, fullgraph=True, backend="aot_eager_decomp_partition"
-                )
-
-                if sample.input.is_contiguous():
-                    compiled_in_f(sample.input, *sample.args, **sample.kwargs)
-                    if op.full_name in COMPARE_TENSOR_COMPONENT_EQUALITY:
-                        self.assertEqualIgnoringNestedInts(sample.input, out_ref)
-                    else:
-                        self.assertEqual(sample.input, out_ref)
-                else:
-                    # see https://github.com/pytorch/pytorch/issues/106456
-                    with self.assertRaisesRegex(
-                        RuntimeError,
-                        "Mutations on non-contiguous inputs are currently not allowed on tensor subclasses",
-                    ):
-                        compiled_in_f(sample.input, *sample.args, **sample.kwargs)
-
     @withXFails(COMPILE_BACKWARD_FAILURES)
     @ops(
         [op for op in njt_op_db if op.supports_njt and op.supports_autograd],
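
The removed test appears to rely on an out-of-bounds ragged index to trigger a device-side assert, which is why it runs in a subprocess and only checks for a nonzero exit code. A minimal CPU sketch of the bounds condition involved, with the same offsets/lengths/indices as the test; variable names are illustrative and this is not the reverted kernel's own check:

import torch

offsets = torch.tensor([0, 2, 5, 7])
lengths = torch.tensor([2, 2, 2])

batch = torch.tensor([0, 1, 2])    # which sequence each update targets
ragged = torch.tensor([0, 2, 1])   # position within that sequence

# Each ragged position must be smaller than the length of its target sequence.
in_bounds = torch.all(ragged < lengths[batch])
print(in_bounds)  # tensor(False): position 2 exceeds sequence 1's length of 2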

torch/nested/_internal/ops.py (-93 lines)

@@ -1558,99 +1558,6 @@ def slice_tensor(func, *args, **kwargs):
     return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
 
 
-@register_jagged_func(
-    torch.ops.aten.index_put.default,
-    "input: jt_all, indices: any, values: t, accumulate: any?",
-)
-@register_jagged_func(
-    torch.ops.aten.index_put_.default,
-    "input: jt_all, indices: any, values: t, accumulate: any?",
-)
-def index_put_(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(  # type: ignore[misc]
-        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
-    )
-
-    inp: NestedTensor = new_kwargs.pop("input")
-
-    # For index_put_ to work, we add together the indices of the ragged dimension
-    # and the batch dimension, adding the offsets of each ragged dimension to its
-    # indices
-
-    indices = new_kwargs.pop("indices")
-
-    assert len(indices) <= inp.dim()
-
-    if len(indices) < inp._ragged_idx + 1:
-        if not inp.is_contiguous():
-            raise RuntimeError(
-                "index_put(): If ragged dimension is not part of indices, this only works on contiguous NJTs"
-            )
-        # Ragged dim is NOT part of indices, we need to pad the nested tensor to apply func
-        from .nested_tensor import nested_from_padded
-
-        min_seqlen = inp._maybe_min_seqlen
-        max_seqlen = inp._maybe_max_seqlen
-        padded_max_S = max_seqlen
-        total_L = inp._values.shape[inp._ragged_idx - 1]
-        if padded_max_S is None:
-            # use upper bound on max seqlen if it's not present
-            padded_max_S = total_L
-
-        padded_shape = (
-            *inp.shape[: inp._ragged_idx],
-            padded_max_S,
-            *inp.shape[inp._ragged_idx + 1 :],
-        )
-        padded_inp = inp.to_padded_tensor(0.0, output_size=padded_shape)
-        new_njt = nested_from_padded(
-            func(padded_inp, indices, **new_kwargs),
-            offsets=inp._offsets,
-            ragged_idx=inp._ragged_idx,
-            sum_S=total_L,
-            min_seqlen=min_seqlen,
-            max_seqlen=max_seqlen,
-        )
-
-        if func == torch.ops.aten.index_put_.default:
-            inp._values.copy_(new_njt.values())
-            return inp
-        return new_njt
-
-    # We can run on the underlying values directly
-
-    # Validate indices
-    if inp.lengths() is None:
-        lengths = inp.offsets().diff()
-    else:
-        lengths = inp.lengths()
-    torch._assert_async(
-        torch.all(indices[inp._ragged_idx] < lengths),
-        "Some indices in the ragged dimension are out of bounds!",
-    )
-
-    # Recompute indices for _values
-    ragged_indices = inp.offsets()[indices[0]] + indices[inp._ragged_idx]
-    func_indices = (
-        # before ragged dim
-        indices[1 : inp._ragged_idx]
-        # ragged dim (combined with batch)
-        + [ragged_indices]
-        # after ragged dim
-        + indices[inp._ragged_idx + 1 :]
-    )
-
-    if func == torch.ops.aten.index_put_.default:
-        inp._values = func(inp._values, func_indices, **new_kwargs)
-        return inp
-
-    return NestedTensor(
-        func(inp._values, func_indices, **new_kwargs),
-        **extract_kwargs(inp),
-        lengths=inp.lengths(),
-    )
-
-
 @register_jagged_func(
     torch.ops.aten.convolution.default,
     "input: jt, weight: t, bias: t?, stride: any, padding: any, "

torch/testing/_internal/opinfo/definitions/nested.py (+1, -64 lines)

@@ -106,29 +106,6 @@ def _slice_input(t, i=i, inp=nt_inp):
         args = tree_map(_slice_input, sample.args)
         kwargs = tree_map(_slice_input, sample.kwargs)
 
-        # Handle indices in index_put
-        if "index_put" in op.full_name and "indices" in kwargs:
-            if len(kwargs["indices"]) > 1:
-                # If after unrolling we still have indices left, use them
-                kwargs["indices"] = [t[i] for t in kwargs["indices"][1:]]
-            else:
-                # If no indices are left, create them so they match the NJT implementation
-                sequence_put = kwargs["indices"][0].tolist()
-                if i in sequence_put:
-                    kwargs["indices"] = [
-                        torch.tensor(
-                            list(range(inp.shape[0])),
-                            dtype=torch.int32,
-                            device=kwargs["indices"][0].device,
-                        )
-                    ]
-                else:
-                    kwargs["indices"] = [
-                        torch.tensor(
-                            [], dtype=torch.int32, device=kwargs["indices"][0].device
-                        )
-                    ]
-
         from torch._prims_common import canonicalize_dims
 
         # Need to adjust dim to apply on NJT component

@@ -138,6 +115,7 @@ def _slice_input(t, i=i, inp=nt_inp):
 
         # TODO: handle this
         assert "dims" not in kwargs
+
         out_ref_component = op.op(inp, *args, **kwargs)
 
         # TODO: handle list / tuple / non-NJT outputs

@@ -471,46 +449,6 @@ def sample_inputs_nn_functional_embedding(
     )
 
 
-def sample_inputs_index_put(
-    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
-):
-    for njt in _sample_njts(
-        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
-    ):
-        for dim in range(njt.dim()):
-            indices = [
-                torch.tensor(list(range(njt.size(0))), device=njt.device),
-                *[
-                    torch.tensor([0] * njt.size(0), device=njt.device)
-                    for _ in range(dim - 1)
-                ],
-            ]
-            yield SampleInput(
-                njt.clone().detach(),
-                kwargs={
-                    "indices": indices,
-                    "values": torch.tensor(1.0, device=njt.device),
-                },
-            )
-
-    # Non-cont NJT for completeness
-    offsets = torch.tensor([0, 2, 5, 7], device=device)
-    lengths = torch.tensor([2, 2, 2], device=device)
-    indices = [
-        torch.tensor([0, 1, 2], device=device),
-        torch.tensor([0, 1, 1], device=device),
-        torch.tensor([0, 0, 0], device=device),
-    ]
-    a = torch.nested.nested_tensor_from_jagged(
-        torch.zeros(7, 3, device=device), offsets, lengths
-    )
-
-    yield SampleInput(
-        a.clone().detach(),
-        kwargs={"indices": indices, "values": torch.tensor(1.0, device=a.device)},
-    )
-
-
 def sample_inputs_nn_functional_embedding_bag(
     op_info, device, dtype, requires_grad, **kwargs
 ):

@@ -653,7 +591,6 @@ def sample_inputs_nn_functional_rms_norm(
     "to": sample_inputs_to,
     "matmul": sample_inputs_matmul,
     "masked_select": sample_inputs_masked_select,
-    "index_put": sample_inputs_index_put,
 }
 
 njt_references = {
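
For reference, the removed generator yields samples whose indices cover every batch entry (an arange over dim 0), pin any additional dims to zero, and pair them with a scalar values tensor. A rough, self-contained sketch of one such sample; `_sample_njts` is an internal helper, so a jagged NJT is built directly here and the exact tensors are illustrative:

import torch
from torch.testing._internal.opinfo.core import SampleInput

njt = torch.nested.nested_tensor_from_jagged(
    torch.zeros(7, 3), offsets=torch.tensor([0, 2, 5, 7])
)
indices = [
    torch.arange(njt.size(0)),                    # one update per batch entry
    torch.zeros(njt.size(0), dtype=torch.int64),  # ragged position pinned to 0
]
sample = SampleInput(
    njt.clone().detach(),
    kwargs={"indices": indices, "values": torch.tensor(1.0)},
)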
