
Commit 1ec30a6

Add offsets-based reduction to segment_reduce (CPU, CUDA)

mikaylagawarecki authored and pytorchmergebot committed

Pull Request resolved: pytorch#78907
Approved by: https://github.com/cpuhrsch

1 parent c978b60 commit 1ec30a6
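The user-visible change is a new `offsets` keyword in the `segment_reduce` schema (see the native_functions.yaml hunk below), as an alternative to `lengths`. A minimal Python sketch of the intended equivalence, assuming the `torch.segment_reduce` binding generated from that schema:

import torch

data = torch.tensor([1., 2., 3., 4., 5., 6.])

# The same three segments described both ways:
lengths = torch.tensor([2, 3, 1])     # N segment sizes
offsets = torch.tensor([0, 2, 5, 6])  # N + 1 cumulative boundaries

# Both calls reduce [1+2, 3+4+5, 6] -> [3., 12., 6.]
out_lengths = torch.segment_reduce(data, "sum", lengths=lengths)
out_offsets = torch.segment_reduce(data, "sum", offsets=offsets)
assert torch.equal(out_lengths, out_offsets)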

File tree

9 files changed (+489, -171 lines)

aten/src/ATen/native/SegmentReduce.cpp

+257-82
Large diffs are not rendered by default.
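The CPU implementation is the bulk of this commit, but its diff is collapsed above. The semantics it adds can be stated compactly: with `offsets`, segment `i` covers `data[offsets[i] : offsets[i + 1]]` along the reduction axis, so `N + 1` boundaries describe `N` segments. A hedged 1-D reference sketch (sum only; the real kernels also handle max/mean/min/prod, `initial`, and arbitrary `axis` — `segment_reduce_sum_1d` is an illustrative name, not a function in this commit):

import torch

def segment_reduce_sum_1d(data: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    # offsets holds N + 1 monotonically non-decreasing boundaries for N segments.
    bounds = offsets.tolist()
    return torch.stack([data[lo:hi].sum() for lo, hi in zip(bounds, bounds[1:])])

print(segment_reduce_sum_1d(
    torch.tensor([1., 2., 3., 4., 5., 6.]),
    torch.tensor([0, 2, 5, 6])))  # tensor([ 3., 12.,  6.])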

aten/src/ATen/native/SegmentReduce.h

+22-4
@@ -11,23 +11,41 @@ namespace native {
 
 enum SegmentReductionType { MAX, MEAN, MIN, SUM, PROD};
 
-using segment_reduce_fn = Tensor (*)(
+using segment_reduce_lengths_fn = Tensor (*)(
     SegmentReductionType,
     const Tensor&,
     const Tensor&,
     int64_t,
     const c10::optional<Scalar>&);
-DECLARE_DISPATCH(segment_reduce_fn, _segment_reduce_stub);
+DECLARE_DISPATCH(segment_reduce_lengths_fn, _segment_reduce_lengths_stub);
 
-using segment_reduce_backward_fn = Tensor (*)(
+using segment_reduce_offsets_fn = Tensor (*)(
+    SegmentReductionType,
+    const Tensor&,
+    const Tensor&,
+    int64_t,
+    const c10::optional<Scalar>&);
+DECLARE_DISPATCH(segment_reduce_offsets_fn, _segment_reduce_offsets_stub);
+
+using segment_reduce_lengths_backward_fn = Tensor (*)(
+    const Tensor&,
+    const Tensor&,
+    const Tensor&,
+    SegmentReductionType,
+    const Tensor&,
+    int64_t,
+    const c10::optional<Scalar>&);
+DECLARE_DISPATCH(segment_reduce_lengths_backward_fn, _segment_reduce_lengths_backward_stub);
+
+using segment_reduce_offsets_backward_fn = Tensor (*)(
     const Tensor&,
     const Tensor&,
     const Tensor&,
     SegmentReductionType,
     const Tensor&,
     int64_t,
     const c10::optional<Scalar>&);
-DECLARE_DISPATCH(segment_reduce_backward_fn, _segment_reduce_backward_stub);
+DECLARE_DISPATCH(segment_reduce_offsets_backward_fn, _segment_reduce_offsets_backward_stub);
 
 } // namespace native
 } // namespace at

aten/src/ATen/native/cuda/SegmentReduce.cu

+100-34
@@ -70,7 +70,7 @@ Tensor _get_complete_sum(const Tensor& lengths) {
   offsets[0].zero_();
 
   AT_DISPATCH_INDEX_TYPES(
-      lengths.scalar_type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
+      lengths.scalar_type(), "_segment_reduce_cuda_lengths_offsets_backward_kernel1", ([&] {
         auto* lengths_data_ptr = lengths.data_ptr<index_t>();
         auto* offsets_data_ptr = offsets.data_ptr<index_t>();
         at::cuda::cub::inclusive_sum(
@@ -278,23 +278,33 @@ __global__ void segment_reduce_backward_kernel(
 }
 } // namespace
 
-Tensor _segment_reduce_cuda_backward_kernel(
+Tensor _segment_reduce_lengths_offsets_backward_cuda_kernel(
     const Tensor& grad_contig,
     const Tensor& output_contig,
     const Tensor& data_contig,
     SegmentReductionType reduction,
-    const Tensor& lengths_contig,
+    const Tensor& lengths_or_offsets_contig,
     int64_t axis,
-    const c10::optional<Scalar>& initial) {
-  axis = lengths_contig.dim() - 1;
-  int64_t segment_count = lengths_contig.size(axis);
-  int64_t lengths_stride_axis = lengths_contig.stride(axis);
+    const c10::optional<Scalar>& initial,
+    bool is_offsets_like) {
+  axis = lengths_or_offsets_contig.dim() - 1;
+  int64_t segment_count = is_offsets_like ?
+      lengths_or_offsets_contig.size(axis) - 1 :
+      lengths_or_offsets_contig.size(axis);
+  int64_t lengths_stride_axis = lengths_or_offsets_contig.stride(axis);
   auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options());
 
-  auto zeros_shape = lengths_contig.sizes().vec();
-  zeros_shape[axis] = 1;
-  auto offsets = at::cat({at::zeros(zeros_shape, lengths_contig.options()), lengths_contig}, axis);
-  offsets.cumsum_(axis);
+  auto offsets = lengths_or_offsets_contig;
+  auto lengths = lengths_or_offsets_contig;
+  if (is_offsets_like) {
+    lengths = lengths.diff();
+  } else {
+    // _get_complete_sum only supports 1D
+    auto zeros_shape = offsets.sizes().vec();
+    zeros_shape[axis] = 1;
+    offsets = at::cat({at::zeros(zeros_shape, offsets.options()), offsets}, axis);
+    offsets.cumsum_(axis);
+  }
 
   // outer_offset is the size of the outer dimensions of output (before axis)
   // inner_offset is the size of the inner dimensions of output (after axis)
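The `is_offsets_like` branch above normalizes whichever representation the caller passed, so the kernel always has both views: `lengths` is recovered as `offsets.diff()`, and `offsets` is built by prepending a zero to `lengths` and taking a running sum. The same round trip in Python (illustrative, not part of the commit):

import torch

lengths = torch.tensor([2, 3, 1])

# lengths -> offsets: mirrors the at::cat + cumsum_ branch.
offsets = torch.cat([torch.zeros(1, dtype=lengths.dtype), lengths]).cumsum(0)
print(offsets)         # tensor([0, 2, 5, 6])

# offsets -> lengths: mirrors the .diff() branch.
print(offsets.diff())  # tensor([2, 3, 1])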
@@ -318,8 +328,8 @@ Tensor _segment_reduce_cuda_backward_kernel(
   auto offsets_stride_axis = offsets.stride(axis);
 
   AT_DISPATCH_INDEX_TYPES(
-      lengths_contig.scalar_type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
-        const auto* lengths_data = lengths_contig.data_ptr<index_t>();
+      lengths_or_offsets_contig.scalar_type(), "_segment_reduce_cuda_lengths_offsets_backward_kernel1", ([&] {
+        const auto* lengths_data = lengths.data_ptr<index_t>();
         auto* offsets_data = offsets.data_ptr<index_t>();
 
         // TODO: Switch to TensorIterator for better maintainablility and
@@ -371,27 +381,59 @@ Tensor _segment_reduce_cuda_backward_kernel(
   return grad_input;
 }
 
-Tensor _segment_reduce_cuda_kernel(
-    SegmentReductionType reduction,
-    const Tensor& data,
-    const Tensor& lengths,
-    int64_t axis,
-    const c10::optional<Scalar>& initial) {
-  // data and lengths should be contiguous from the call to .contiguous in segment_reduce_kernel
-  TORCH_CHECK(data.is_contiguous(), "Expected data to be contiguous.");
-  TORCH_CHECK(lengths.is_contiguous(), "Expected lengths to be contiguous.");
-  axis = lengths.dim() - 1;
-  int64_t segment_count = lengths.size(axis);
-  int64_t lengths_stride_axis = lengths.stride(axis);
+Tensor _segment_reduce_lengths_backward_cuda_kernel(
+    const Tensor& grad_contig,
+    const Tensor& output_contig,
+    const Tensor& data_contig,
+    SegmentReductionType reduction,
+    const Tensor& lengths_contig,
+    int64_t axis,
+    const c10::optional<Scalar>& initial) {
+  return _segment_reduce_lengths_offsets_backward_cuda_kernel(
+      grad_contig, output_contig, data_contig, reduction, lengths_contig, axis, initial, /*is_offsets_like=*/false);
+}
+
+Tensor _segment_reduce_offsets_backward_cuda_kernel(
+    const Tensor& grad_contig,
+    const Tensor& output_contig,
+    const Tensor& data_contig,
+    SegmentReductionType reduction,
+    const Tensor& offsets_contig,
+    int64_t axis,
+    const c10::optional<Scalar>& initial) {
+  return _segment_reduce_lengths_offsets_backward_cuda_kernel(
+      grad_contig, output_contig, data_contig, reduction, offsets_contig, axis, initial, /*is_offsets_like=*/true);
+}
+
+Tensor _segment_reduce_lengths_offsets_cuda_kernel(
+    SegmentReductionType reduction,
+    const Tensor& data,
+    const Tensor& lengths_or_offsets,
+    int64_t axis,
+    const c10::optional<Scalar>& initial,
+    bool is_offsets_like) {
+  // data and lengths_or_offsets should be contiguous from the call to .contiguous in segment_reduce_kernel
+  TORCH_CHECK(data.is_contiguous());
+  TORCH_CHECK(lengths_or_offsets.is_contiguous());
+  axis = lengths_or_offsets.dim() - 1;
+  int64_t segment_count = is_offsets_like ? lengths_or_offsets.size(axis) - 1 : lengths_or_offsets.size(axis);
+  int64_t lengths_stride_axis = lengths_or_offsets.stride(axis);
   auto output_shape = data.sizes().vec();
   output_shape[axis] = segment_count;
   auto output = at::empty(output_shape, data.options());
 
-  // _get_complete_sum only supports 1D?
-  auto zeros_shape = lengths.sizes().vec();
-  zeros_shape[axis] = 1;
-  auto offsets = at::cat({at::zeros(zeros_shape, lengths.options()), lengths}, axis);
-  offsets.cumsum_(axis);
+
+  auto offsets = lengths_or_offsets;
+  auto lengths = lengths_or_offsets;
+  if (is_offsets_like) {
+    lengths = lengths.diff();
+  } else {
+    // _get_complete_sum only supports 1D
+    auto zeros_shape = offsets.sizes().vec();
+    zeros_shape[axis] = 1;
+    offsets = at::cat({at::zeros(zeros_shape, offsets.options()), offsets}, axis);
+    offsets.cumsum_(axis);
+  }
 
   // outer_offset is the size of the outer dimensions of output (before axis)
   // inner_offset is the size of the inner dimensions of output (after axis)
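The ternary above encodes the one counting difference between the two conventions: `N` lengths describe `N` segments, while offsets need `N + 1` entries for `N` segments. Spelled out (illustrative):

import torch

lengths = torch.tensor([2, 3, 1])
offsets = torch.tensor([0, 2, 5, 6])

assert lengths.size(-1) == 3      # N lengths -> N segments
assert offsets.size(-1) - 1 == 3  # N + 1 offsets -> N segments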
@@ -416,7 +458,7 @@ Tensor _segment_reduce_cuda_kernel(
   auto offsets_stride_axis = offsets.stride(axis);
 
   AT_DISPATCH_INDEX_TYPES(
-      lengths.scalar_type(), "_segment_reduce_cuda_kernel1", ([&] {
+      lengths_or_offsets.scalar_type(), "_segment_reduce_cuda_kernel1", ([&] {
         auto* offsets_data_ptr = offsets.data_ptr<index_t>();
         auto* lengths_data_ptr = lengths.data_ptr<index_t>();
         AT_DISPATCH_FLOATING_TYPES_AND2(
@@ -549,10 +591,34 @@ Tensor _segment_reduce_cuda_kernel(
   return output;
 }
 
-REGISTER_DISPATCH(_segment_reduce_stub, &_segment_reduce_cuda_kernel);
+Tensor _segment_reduce_lengths_cuda_kernel(
+    SegmentReductionType reduction,
+    const Tensor& data,
+    const Tensor& lengths,
+    int64_t axis,
+    const c10::optional<Scalar>& initial) {
+  return _segment_reduce_lengths_offsets_cuda_kernel(
+      reduction, data, lengths, axis, initial, /*is_offsets_like=*/false);
+}
+
+Tensor _segment_reduce_offsets_cuda_kernel(
+    SegmentReductionType reduction,
+    const Tensor& data,
+    const Tensor& offsets,
+    int64_t axis,
+    const c10::optional<Scalar>& initial) {
+  return _segment_reduce_lengths_offsets_cuda_kernel(
+      reduction, data, offsets, axis, initial, /*is_offsets_like=*/true);
+}
+
+REGISTER_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cuda_kernel);
+REGISTER_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cuda_kernel);
+REGISTER_DISPATCH(
+    _segment_reduce_lengths_backward_stub,
+    &_segment_reduce_lengths_backward_cuda_kernel);
 REGISTER_DISPATCH(
-    _segment_reduce_backward_stub,
-    &_segment_reduce_cuda_backward_kernel);
+    _segment_reduce_offsets_backward_stub,
+    &_segment_reduce_offsets_backward_cuda_kernel);
 
 } // namespace native
 } // namespace at

aten/src/ATen/native/native_functions.yaml

+2-2
@@ -11987,12 +11987,12 @@
   dispatch:
     CompositeExplicitAutograd: _test_warn_in_autograd
 
-- func: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor
+- func: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor
   variants: function
   dispatch:
     CPU, CUDA: segment_reduce_kernel
 
-- func: _segment_reduce_backward(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, int axis=0, Scalar? initial=None) -> Tensor
+- func: _segment_reduce_backward(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, Tensor? offsets=None, int axis=0, Scalar? initial=None) -> Tensor
   variants: function
   dispatch:
     CPU, CUDA: _segment_reduce_backward_kernel
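Because `_segment_reduce_backward` also gains the `offsets` argument, gradients should flow through an offsets-based call the same way they do for lengths. A hedged smoke test, again assuming the `torch.segment_reduce` binding:

import torch

data = torch.tensor([1., 2., 3., 4., 5., 6.], requires_grad=True)
offsets = torch.tensor([0, 2, 5, 6])

out = torch.segment_reduce(data, "sum", offsets=offsets)
out.sum().backward()
print(data.grad)  # for "sum", every element contributes once: tensor([1., 1., 1., 1., 1., 1.])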

test/forward_backward_compatibility/check_forward_backward_compatibility.py

+2
@@ -143,6 +143,8 @@
     ("aten::_csr_to_block_csr", datetime.date(2022, 5, 20)),
     ("aten::_weight_norm_cuda_interface", datetime.date(9999, 1, 1)),
     ("aten::_weight_norm_cuda_interface_backward", datetime.date(9999, 1, 1)),
+    ("aten::segment_reduce", datetime.date(2022, 6, 30)),
+    ("aten::_segment_reduce_backward", datetime.date(2022, 6, 30)),
     # TODO: FIXME: prims shouldn't be checked
     ("prims::.*", datetime.date(9999, 1, 1)),
 ]
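These two allowlist entries suppress the forward/backward schema-compatibility check for `aten::segment_reduce` and `aten::_segment_reduce_backward` until 2022-06-30, since adding the `offsets` argument changes both signatures.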
