#include <ATen/core/op_registration/adaption.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/IndexKernel.h>
+ #include <ATen/native/ReductionType.h>
#include <ATen/native/TensorAdvancedIndexing.h>
#include <ATen/native/TensorAdvancedIndexingUtils.h>
#include <ATen/native/TensorIterator.h>
+ // #include <ATen/native/TensorFactories.cpp>
#include <ATen/native/xpu/sycl/IndexingKernels.h>
#include <ATen/native/xpu/sycl/ScatterGatherKernels.h>
+ #include <ATen/ops/ones_like.h>
+ #include <ATen/ops/zeros_like.h>
#include <comm/RegisterUtils.h>
#include <comm/xpu_aten.h>
#include <torch/library.h>

#include <ATen/ops/index_add_meta.h>
+ #include <ATen/ops/index_reduce_meta.h>
#include <xpu/ATen/ops/index_add_native.h>
+ #include <xpu/ATen/ops/index_reduce_native.h> // generated
+ // #include <xpu/ATen/ops/index_reduce_prod_native.h> // generated

namespace at {

@@ -42,6 +49,7 @@ REGISTER_XPU_DISPATCH(index_fill_stub, &xpu::index_fill_kernel);
REGISTER_XPU_DISPATCH(index_copy_stub, &xpu::index_copy_kernel);
REGISTER_XPU_DISPATCH(put_stub, &xpu::put_kernel);
REGISTER_XPU_DISPATCH(take_stub, &xpu::take_kernel);
+ // REGISTER_XPU_DISPATCH(index_reduce_stub, &xpu::index_reduce_kernel);
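+ // Note: index_reduce is implemented as a structured op via
+ // TORCH_IMPL_FUNC(index_reduce_xpu_out) below, so no stub registration
+ // is needed here.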

TORCH_IMPL_FUNC(index_add_xpu_out)
(const Tensor& self,
@@ -126,5 +134,44 @@ Tensor count_nonzero_xpu(const Tensor& self, IntArrayRef dims) {
  return (self != 0).sum(dims);
}

+ TORCH_IMPL_FUNC(index_reduce_xpu_out)
+ (const Tensor& self,
+  int64_t dim,
+  const Tensor& index,
+  const Tensor& source,
+  const c10::string_view reduce,
+  bool include_self,
+  const Tensor& result) {
+   TORCH_WARN_ONCE(
+       "index_reduce() is in beta and the API may change at any time.");
+   if (reduce == "prod") {
+     xpu::index_reduce_prod_kernel(
+         self, dim, index, source, include_self, ReductionType::PROD, result);
+   } else if (reduce == "mean") {
+     xpu::index_reduce_mean_kernel(
+         self, dim, index, source, include_self, ReductionType::MEAN, result);
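+     // The mean kernel only accumulates sums into result; recover the mean
+     // by dividing each element by the number of contributions it received.
+     // include_self seeds every count with 1, since self's original value
+     // takes part in the reduction.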
+     auto counts = include_self ? ones_like(result) : zeros_like(result);
+     counts.index_add_(dim, index, ones_like(source));
+     counts.masked_fill_(counts == 0, 1);
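+     // Elements that received no contribution had their count clamped to 1
+     // above, so the division leaves them unchanged. Integral results use
+     // floor rounding because true division would need a float dtype.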
+     if (result.is_floating_point() || result.is_complex()) {
+       result.div_(counts);
+     } else {
+       result.div_(counts, "floor");
+     }
+   } else if (reduce == "amax") {
+     xpu::index_reduce_amax_kernel(
+         self, dim, index, source, include_self, ReductionType::MAX, result);
+   } else if (reduce == "amin") {
+     xpu::index_reduce_amin_kernel(
+         self, dim, index, source, include_self, ReductionType::MIN, result);
+   } else {
+     TORCH_CHECK(
+         false,
+         "Only support prod, mean, amax or amin reduce operator. Input was ",
+         reduce,
+         ".");
+   }
+ }
+
} // namespace native
} // namespace at
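
For reference, a minimal sketch (not part of this commit) of what the new mean path computes, assuming an XPU-enabled libtorch build; the device handling, tensor values, and file name here are illustrative only:

// index_reduce_example.cpp -- hypothetical usage sketch.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto opts = torch::TensorOptions().device(torch::kXPU);
  // Three destination rows, reduced along dim 0.
  auto self = torch::ones({3, 2}, opts);
  auto source =
      torch::arange(10, opts.dtype(torch::kFloat)).reshape({5, 2});
  auto index = torch::tensor({0, 0, 1, 2, 2}, opts.dtype(torch::kLong));
  // With include_self=true, output row 0 is the mean of self[0] and
  // source rows 0 and 1 (count = 3), matching the counts logic above.
  auto out = self.index_reduce(0, index, source, "mean",
                               /*include_self=*/true);
  std::cout << out.to(torch::kCPU) << std::endl;
  return 0;
}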