@@ -15,18 +15,11 @@
 #include <ATen/native/CompositeRandomAccessor.h>
 #include <ATen/native/TopKImpl.h>
 #include <c10/core/WrapDimMinimal.h>
-#include <c10/util/SmallBuffer.h>
 #include <c10/util/irange.h>
-
 #ifdef USE_FBGEMM
 #include <fbgemm/Utils.h>
 #endif
 
-#if USE_X86_SIMD_SORT && (defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2))
-#define XSS_COMPILE_TIME_SUPPORTED
-#include <src/x86simdsort-static-incl.h>
-#endif
-
 namespace at::native {
 
 namespace {
@@ -124,7 +117,6 @@ static void parallel_sort1d_kernel(
   std::vector<int64_t> tmp_vals(elements);
   const scalar_t* sorted_keys = nullptr;
   const int64_t* sorted_vals = nullptr;
-
   std::tie(sorted_keys, sorted_vals) = fbgemm::radix_sort_parallel(
       keys,
       vals,
@@ -173,116 +165,6 @@ static inline void sort_kernel_impl(const value_accessor_t& value_accessor,
   }
 }
 
-#if defined(XSS_COMPILE_TIME_SUPPORTED)
-
-#define AT_DISPATCH_CASE_XSS_TYPES(...)                  \
-  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)    \
-  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)     \
-  AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__)  \
-  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)
-
-#define AT_DISPATCH_XSS_TYPES(TYPE, NAME, ...) \
-  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_XSS_TYPES(__VA_ARGS__))
-
-static bool can_use_xss_sort(const TensorBase& values, const TensorBase& indices, int64_t dim, const bool stable) {
-  // xss_sort is not a stable sort
-  if (stable) return false;
-
-  auto type = values.scalar_type();
-  if (! (type == ScalarType::Long || type == ScalarType::Int || type == ScalarType::Double || type == ScalarType::Float)) return false;
-
-  return true;
-}
-
-static bool xss_sort_preferred(const TensorBase& values, const bool descending) {
-#if defined(XSS_USE_OPENMP) || !defined(USE_FBGEMM)
-  return true;
-#else
-  // Without OpenMP support for x86-simd-sort, fbgemm radix sort is faster when it can be used
-  return !can_use_radix_sort(values, descending);
-#endif
-}
-
-static void xss_sort_kernel(
-    const TensorBase& values,
-    const TensorBase& indices,
-    int64_t dim,
-    bool descending) {
-  auto iter = TensorIteratorConfig()
-    .check_all_same_dtype(false)
-    .resize_outputs(false)
-    .declare_static_shape(values.sizes(), /*squash_dims=*/dim)
-    .add_output(values)
-    .add_output(indices)
-    .build();
-
-  using index_t = int64_t;
-
-  AT_DISPATCH_XSS_TYPES(values.scalar_type(), "xss_sort_kernel", [&] {
-
-    auto values_dim_stride = values.stride(dim);
-    auto indices_dim_stride = indices.stride(dim);
-    auto dim_size = values.size(dim);
-
-    auto loop = [&](char** data, const int64_t* strides, int64_t n) {
-      auto* values_data_bytes = data[0];
-      auto* indices_data_bytes = data[1];
-
-      if (values_data_bytes == nullptr || indices_data_bytes == nullptr) {
-        return;
-      }
-
-      if (values_dim_stride == 1 && indices_dim_stride == 1) {
-        for (const auto i [[maybe_unused]] : c10::irange(n)) {
-          x86simdsortStatic::keyvalue_qsort<scalar_t, index_t>(
-              reinterpret_cast<scalar_t*>(values_data_bytes),
-              reinterpret_cast<index_t*>(indices_data_bytes),
-              dim_size,
-              true,
-              descending);
-
-          values_data_bytes += strides[0];
-          indices_data_bytes += strides[1];
-        }
-      } else {
-        c10::SmallBuffer<scalar_t, 0> tmp_values(dim_size);
-        c10::SmallBuffer<index_t, 0> tmp_indices(dim_size);
-
-        for (const auto i : c10::irange(n)) {
-          TensorAccessor<scalar_t, 1> mode_values_acc(
-              reinterpret_cast<scalar_t*>(data[0] + i * strides[0]),
-              &dim_size, &values_dim_stride);
-          TensorAccessor<index_t, 1> mode_indices_acc(
-              reinterpret_cast<index_t*>(data[1] + i * strides[1]),
-              &dim_size, &indices_dim_stride);
-
-          for (const auto j : c10::irange(dim_size)) {
-            tmp_values[j] = mode_values_acc[j];
-            tmp_indices[j] = j;
-          }
-
-          x86simdsortStatic::keyvalue_qsort<scalar_t, index_t>(
-              tmp_values.data(),
-              tmp_indices.data(),
-              dim_size,
-              true,
-              descending);
-
-          for (const auto j : c10::irange(dim_size)) {
-            mode_values_acc[j] = tmp_values[j];
-            mode_indices_acc[j] = tmp_indices[j];
-          }
-        }
-      }
-    };
-
-    int64_t grain_size = internal::GRAIN_SIZE / std::max(int64_t{1}, dim_size);
-    iter.for_each(loop, /*grain_size=*/grain_size);
-
-  });
-}
-#endif
-
 static void sort_kernel(
     const TensorBase& self,
     const TensorBase& values,
@@ -297,14 +179,6 @@ static void sort_kernel(
     // https://github.com/pytorch/pytorch/issues/91420
     return;
   }
-
-#if defined(XSS_COMPILE_TIME_SUPPORTED)
-  if (can_use_xss_sort(values, indices, dim, stable) && xss_sort_preferred(values, descending)) {
-    xss_sort_kernel(values, indices, dim, descending);
-    return;
-  }
-#endif
-
 #ifdef USE_FBGEMM
   if (can_use_radix_sort(values, descending)) {
     parallel_sort1d_kernel(values, indices);
@@ -356,7 +230,6 @@ static void topk_kernel(
     int64_t dim,
     bool largest,
     bool sorted) {
-
   auto sizes = self.sizes();
   auto iter = TensorIteratorConfig()
     .check_all_same_dtype(false)
@@ -391,7 +264,7 @@ static void topk_kernel(
 
 } // anonymous namespace
 
-ALSO_REGISTER_AVX512_DISPATCH(sort_stub, &sort_kernel)
-ALSO_REGISTER_AVX512_DISPATCH(topk_stub, &topk_kernel)
+REGISTER_DISPATCH(sort_stub, &sort_kernel)
+REGISTER_DISPATCH(topk_stub, &topk_kernel)
 
 } // at::native