@@ -70,7 +70,7 @@ Tensor _get_complete_sum(const Tensor& lengths) {
  offsets[0].zero_();

  AT_DISPATCH_INDEX_TYPES(
-      lengths.scalar_type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
+      lengths.scalar_type(), "_segment_reduce_cuda_lengths_offsets_backward_kernel1", ([&] {
        auto* lengths_data_ptr = lengths.data_ptr<index_t>();
        auto* offsets_data_ptr = offsets.data_ptr<index_t>();
        at::cuda::cub::inclusive_sum(
@@ -278,23 +278,33 @@ __global__ void segment_reduce_backward_kernel(
}
} // namespace

-Tensor _segment_reduce_cuda_backward_kernel(
+Tensor _segment_reduce_lengths_offsets_backward_cuda_kernel(
    const Tensor& grad_contig,
    const Tensor& output_contig,
    const Tensor& data_contig,
    SegmentReductionType reduction,
-    const Tensor& lengths_contig,
+    const Tensor& lengths_or_offsets_contig,
    int64_t axis,
-    const c10::optional<Scalar>& initial) {
-  axis = lengths_contig.dim() - 1;
-  int64_t segment_count = lengths_contig.size(axis);
-  int64_t lengths_stride_axis = lengths_contig.stride(axis);
+    const c10::optional<Scalar>& initial,
+    bool is_offsets_like) {
+  axis = lengths_or_offsets_contig.dim() - 1;
+  int64_t segment_count = is_offsets_like ?
+      lengths_or_offsets_contig.size(axis) - 1 :
+      lengths_or_offsets_contig.size(axis);
+  int64_t lengths_stride_axis = lengths_or_offsets_contig.stride(axis);
  auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options());

-  auto zeros_shape = lengths_contig.sizes().vec();
-  zeros_shape[axis] = 1;
-  auto offsets = at::cat({at::zeros(zeros_shape, lengths_contig.options()), lengths_contig}, axis);
-  offsets.cumsum_(axis);
+  auto offsets = lengths_or_offsets_contig;
+  auto lengths = lengths_or_offsets_contig;
+  if (is_offsets_like) {
+    lengths = lengths.diff();
+  } else {
+    // _get_complete_sum only supports 1D
+    auto zeros_shape = offsets.sizes().vec();
+    zeros_shape[axis] = 1;
+    offsets = at::cat({at::zeros(zeros_shape, offsets.options()), offsets}, axis);
+    offsets.cumsum_(axis);
+  }

  // outer_offset is the size of the outer dimensions of output (before axis)
  // inner_offset is the size of the inner dimensions of output (after axis)
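The two segment descriptions this hunk unifies are interchangeable: lengths become offsets by prepending a zero and taking a cumulative sum, and offsets become lengths by first differences, which is what the `is_offsets_like` branch above relies on. A minimal standalone sketch of the same conversion, assuming the standard libtorch C++ API (the values here are illustrative only):

#include <torch/torch.h>
#include <iostream>

int main() {
  // lengths {2, 3, 1} describe three segments along a length-6 axis.
  auto lengths = torch::tensor({2, 3, 1}, torch::kLong);

  // lengths -> offsets: prepend a zero, then take a cumulative sum.
  // Result: {0, 2, 5, 6} -- one more entry than there are segments,
  // which is why the offsets path uses size(axis) - 1 as segment_count.
  auto offsets = torch::cat({torch::zeros({1}, lengths.options()), lengths}, 0);
  offsets.cumsum_(0);

  // offsets -> lengths: first differences recover {2, 3, 1},
  // mirroring the lengths = lengths.diff() branch in the kernel.
  auto lengths_again = offsets.diff();

  std::cout << offsets << "\n" << lengths_again << "\n";
}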
@@ -318,8 +328,8 @@ Tensor _segment_reduce_cuda_backward_kernel(
  auto offsets_stride_axis = offsets.stride(axis);

  AT_DISPATCH_INDEX_TYPES(
-      lengths_contig.scalar_type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
-        const auto* lengths_data = lengths_contig.data_ptr<index_t>();
+      lengths_or_offsets_contig.scalar_type(), "_segment_reduce_cuda_lengths_offsets_backward_kernel1", ([&] {
+        const auto* lengths_data = lengths.data_ptr<index_t>();
        auto* offsets_data = offsets.data_ptr<index_t>();

        // TODO: Switch to TensorIterator for better maintainablility and
@@ -371,27 +381,59 @@ Tensor _segment_reduce_cuda_backward_kernel(
  return grad_input;
}

-Tensor _segment_reduce_cuda_kernel(
-    SegmentReductionType reduction,
-    const Tensor& data,
-    const Tensor& lengths,
-    int64_t axis,
-    const c10::optional<Scalar>& initial) {
-  // data and lengths should be contiguous from the call to .contiguous in segment_reduce_kernel
-  TORCH_CHECK(data.is_contiguous(), "Expected data to be contiguous.");
-  TORCH_CHECK(lengths.is_contiguous(), "Expected lengths to be contiguous.");
-  axis = lengths.dim() - 1;
-  int64_t segment_count = lengths.size(axis);
-  int64_t lengths_stride_axis = lengths.stride(axis);
+Tensor _segment_reduce_lengths_backward_cuda_kernel(
+    const Tensor& grad_contig,
+    const Tensor& output_contig,
+    const Tensor& data_contig,
+    SegmentReductionType reduction,
+    const Tensor& lengths_contig,
+    int64_t axis,
+    const c10::optional<Scalar>& initial) {
+  return _segment_reduce_lengths_offsets_backward_cuda_kernel(
+      grad_contig, output_contig, data_contig, reduction, lengths_contig, axis, initial, /*is_offsets_like=*/false);
+}
+
+Tensor _segment_reduce_offsets_backward_cuda_kernel(
+    const Tensor& grad_contig,
+    const Tensor& output_contig,
+    const Tensor& data_contig,
+    SegmentReductionType reduction,
+    const Tensor& offsets_contig,
+    int64_t axis,
+    const c10::optional<Scalar>& initial) {
+  return _segment_reduce_lengths_offsets_backward_cuda_kernel(
+      grad_contig, output_contig, data_contig, reduction, offsets_contig, axis, initial, /*is_offsets_like=*/true);
+}
+
+Tensor _segment_reduce_lengths_offsets_cuda_kernel(
+    SegmentReductionType reduction,
+    const Tensor& data,
+    const Tensor& lengths_or_offsets,
+    int64_t axis,
+    const c10::optional<Scalar>& initial,
+    bool is_offsets_like) {
+  // data and lengths_or_offsets should be contiguous from the call to .contiguous in segment_reduce_kernel
+  TORCH_CHECK(data.is_contiguous());
+  TORCH_CHECK(lengths_or_offsets.is_contiguous());
+  axis = lengths_or_offsets.dim() - 1;
+  int64_t segment_count = is_offsets_like ? lengths_or_offsets.size(axis) - 1 : lengths_or_offsets.size(axis);
+  int64_t lengths_stride_axis = lengths_or_offsets.stride(axis);
  auto output_shape = data.sizes().vec();
  output_shape[axis] = segment_count;
  auto output = at::empty(output_shape, data.options());

-  // _get_complete_sum only supports 1D?
-  auto zeros_shape = lengths.sizes().vec();
-  zeros_shape[axis] = 1;
-  auto offsets = at::cat({at::zeros(zeros_shape, lengths.options()), lengths}, axis);
-  offsets.cumsum_(axis);
+
+  auto offsets = lengths_or_offsets;
+  auto lengths = lengths_or_offsets;
+  if (is_offsets_like) {
+    lengths = lengths.diff();
+  } else {
+    // _get_complete_sum only supports 1D
+    auto zeros_shape = offsets.sizes().vec();
+    zeros_shape[axis] = 1;
+    offsets = at::cat({at::zeros(zeros_shape, offsets.options()), offsets}, axis);
+    offsets.cumsum_(axis);
+  }

  // outer_offset is the size of the outer dimensions of output (before axis)
  // inner_offset is the size of the inner dimensions of output (after axis)
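To make those two comments concrete: with outer_offset the product of the dims before axis and inner_offset the product of the dims after it, a flat output index decomposes into (outer, segment, inner) coordinates. A small self-contained sketch of that arithmetic (a hypothetical helper for illustration, not the kernel's actual code):

#include <cstdint>
#include <cassert>

// Decompose a flat output index into (outer, segment, inner) coordinates,
// following the outer_offset/inner_offset convention described above.
void decompose(int64_t idx, int64_t segment_count, int64_t inner_offset,
               int64_t* outer, int64_t* segment, int64_t* inner) {
  *inner = idx % inner_offset;
  *segment = (idx / inner_offset) % segment_count;
  *outer = idx / (inner_offset * segment_count);
}

int main() {
  // Output shape {4, 3, 5} reduced along axis = 1:
  // outer_offset = 4, segment_count = 3, inner_offset = 5.
  int64_t outer, segment, inner;
  decompose(/*idx=*/37, /*segment_count=*/3, /*inner_offset=*/5,
            &outer, &segment, &inner);
  assert(outer == 2 && segment == 1 && inner == 2);  // 37 = 2*15 + 1*5 + 2
}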
@@ -416,7 +458,7 @@ Tensor _segment_reduce_cuda_kernel(
  auto offsets_stride_axis = offsets.stride(axis);

  AT_DISPATCH_INDEX_TYPES(
-      lengths.scalar_type(), "_segment_reduce_cuda_kernel1", ([&] {
+      lengths_or_offsets.scalar_type(), "_segment_reduce_cuda_kernel1", ([&] {
        auto* offsets_data_ptr = offsets.data_ptr<index_t>();
        auto* lengths_data_ptr = lengths.data_ptr<index_t>();
        AT_DISPATCH_FLOATING_TYPES_AND2(
@@ -549,10 +591,34 @@ Tensor _segment_reduce_cuda_kernel(
  return output;
}

-REGISTER_DISPATCH(_segment_reduce_stub, &_segment_reduce_cuda_kernel);
+Tensor _segment_reduce_lengths_cuda_kernel(
+    SegmentReductionType reduction,
+    const Tensor& data,
+    const Tensor& lengths,
+    int64_t axis,
+    const c10::optional<Scalar>& initial) {
+  return _segment_reduce_lengths_offsets_cuda_kernel(
+      reduction, data, lengths, axis, initial, /*is_offsets_like=*/false);
+}
+
+Tensor _segment_reduce_offsets_cuda_kernel(
+    SegmentReductionType reduction,
+    const Tensor& data,
+    const Tensor& offsets,
+    int64_t axis,
+    const c10::optional<Scalar>& initial) {
+  return _segment_reduce_lengths_offsets_cuda_kernel(
+      reduction, data, offsets, axis, initial, /*is_offsets_like=*/true);
+}
+
+REGISTER_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cuda_kernel);
+REGISTER_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cuda_kernel);
+REGISTER_DISPATCH(
+    _segment_reduce_lengths_backward_stub,
+    &_segment_reduce_lengths_backward_cuda_kernel);
REGISTER_DISPATCH(
-    _segment_reduce_backward_stub,
-    &_segment_reduce_cuda_backward_kernel);
+    _segment_reduce_offsets_backward_stub,
+    &_segment_reduce_offsets_backward_cuda_kernel);

} // namespace native
} // namespace at
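For context on how these registrations get exercised: DispatchStub-based stubs are invoked with the device type as their first argument, so a device-generic caller can pick the lengths or offsets path and land on the CUDA kernels registered above. A rough sketch of that call-site pattern, with a hypothetical wrapper name and assumed signatures rather than the verbatim code in native/SegmentReduce.cpp:

// Hypothetical caller; assumes the stub declarations (DECLARE_DISPATCH)
// and ATen types from the surrounding headers are visible.
Tensor segment_reduce_dispatch(
    SegmentReductionType reduction,
    const Tensor& data,
    const c10::optional<Tensor>& lengths,
    const c10::optional<Tensor>& offsets,
    int64_t axis,
    const c10::optional<Scalar>& initial) {
  auto data_contig = data.contiguous();
  if (offsets.has_value()) {
    // Offsets path: n + 1 boundaries describe n segments.
    auto offsets_contig = offsets->contiguous();
    return _segment_reduce_offsets_stub(
        data_contig.device().type(), reduction, data_contig, offsets_contig, axis, initial);
  }
  // Lengths path: one entry per segment.
  auto lengths_contig = lengths->contiguous();
  return _segment_reduce_lengths_stub(
      data_contig.device().type(), reduction, data_contig, lengths_contig, axis, initial);
}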