18
18
19
19
#include < breeze/functions/reduce.h>
20
20
#include < breeze/functions/scan.h>
21
+ #include < breeze/functions/sort.h>
21
22
#include < breeze/functions/store.h>
22
23
#include < breeze/platforms/platform.h>
23
24
#include < breeze/utils/types.h>
24
25
#include < breeze/platforms/cuda.cuh>
25
- #include < cub/block/block_radix_sort.cuh>
26
26
#include " velox/experimental/wave/common/CudaUtil.cuh"
27
27
28
28
// / Utilities for booleans and indices and thread blocks.
@@ -137,8 +137,12 @@ template <
137
137
int32_t kItemsPerThread ,
138
138
typename Key,
139
139
typename Value>
140
- using RadixSort =
141
- typename cub::BlockRadixSort<Key, kBlockSize , kItemsPerThread , Value>;
140
+ using RadixSort = typename breeze::functions::BlockRadixSort<
141
+ CudaPlatform<kBlockSize , kWarpThreads >,
142
+ kItemsPerThread ,
143
+ /* RADIX_BITS=*/ 8 ,
144
+ Key,
145
+ Value>;
142
146
143
147
template <
144
148
int32_t kBlockSize ,
@@ -147,7 +151,7 @@ template <
147
151
typename Value>
148
152
inline int32_t __host__ __device__ blockSortSharedSize () {
149
153
return sizeof (
150
- typename RadixSort<kBlockSize , kItemsPerThread , Key, Value>::TempStorage );
154
+ typename RadixSort<kBlockSize , kItemsPerThread , Key, Value>::Scratch );
151
155
}
152
156
153
157
template <
@@ -165,7 +169,9 @@ void __device__ blockSort(
165
169
char * smem) {
166
170
using namespace breeze ::functions;
167
171
using namespace breeze ::utils;
168
- using Sort = cub::BlockRadixSort<Key, kBlockSize , kItemsPerThread , Value>;
172
+
173
+ CudaPlatform<kBlockSize , kWarpThreads > p;
174
+ using RadixSortT = RadixSort<kBlockSize , kItemsPerThread , Key, Value>;
169
175
170
176
// Per-thread tile items
171
177
Key keys[kItemsPerThread ];
@@ -174,28 +180,37 @@ void __device__ blockSort(
174
180
// Our current block's offset
175
181
int blockOffset = 0 ;
176
182
177
- // Load items into a blocked arrangement
183
+ constexpr int32_t kWarpItems = kWarpThreads * kItemsPerThread ;
184
+ static_assert (
185
+ (kBlockSize % kWarpThreads ) == 0 ,
186
+ " kBlockSize must be a multiple of kWarpThreads" );
187
+
188
+ // Load items into a warp-striped arrangement
189
+ int32_t threadOffset = p.warp_idx () * kWarpItems + p.lane_idx ();
178
190
for (auto i = 0 ; i < kItemsPerThread ; ++i) {
179
- int32_t idx = blockOffset + i * kBlockSize + threadIdx . x ;
191
+ int32_t idx = blockOffset + threadOffset + i * kWarpThreads ;
180
192
values[i] = valueGetter (idx);
181
193
keys[i] = keyGetter (idx);
182
194
}
183
195
184
196
__syncthreads ();
185
- auto * temp_storage = reinterpret_cast <typename Sort::TempStorage *>(smem);
197
+ auto * temp_storage = reinterpret_cast <typename RadixSortT::Scratch *>(smem);
186
198
187
- Sort (*temp_storage).SortBlockedToStriped (keys, values);
199
+ RadixSortT::Sort (
200
+ p,
201
+ make_slice<THREAD, WARP_STRIPED>(keys),
202
+ make_slice<THREAD, WARP_STRIPED>(values),
203
+ make_slice (temp_storage).template reinterpret <SHARED>());
188
204
189
- // Store a striped arrangement of output across the thread block into a linear
190
- // segment of items
191
- CudaPlatform<kBlockSize , kWarpThreads > p;
205
+ // Store a warp-striped arrangement of output across the thread block into a
206
+ // linear segment of items
192
207
BlockStore<kBlockSize , kItemsPerThread >(
193
208
p,
194
- make_slice<THREAD, STRIPED >(values),
209
+ make_slice<THREAD, WARP_STRIPED >(values),
195
210
make_slice<GLOBAL>(valueOut + blockOffset));
196
211
BlockStore<kBlockSize , kItemsPerThread >(
197
212
p,
198
- make_slice<THREAD, STRIPED >(keys),
213
+ make_slice<THREAD, WARP_STRIPED >(keys),
199
214
make_slice<GLOBAL>(keyOut + blockOffset));
200
215
__syncthreads ();
201
216
}
0 commit comments