Skip to content

Commit 4bf8ffa

Browse files
author
David Reveman
committed
build: use breeze BlockRadixSort instead of cub::BlockRadixSort
Also switches to a warp-striped arrangement, which is more efficient for radix sort and is required by breeze.
1 parent 99e1958 commit 4bf8ffa

File tree

2 files changed

+33
-22
lines changed

2 files changed

+33
-22
lines changed

velox/experimental/wave/common/Block.cuh

+29-14
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818

1919
#include <breeze/functions/reduce.h>
2020
#include <breeze/functions/scan.h>
21+
#include <breeze/functions/sort.h>
2122
#include <breeze/functions/store.h>
2223
#include <breeze/platforms/platform.h>
2324
#include <breeze/utils/types.h>
2425
#include <breeze/platforms/cuda.cuh>
25-
#include <cub/block/block_radix_sort.cuh>
2626
#include "velox/experimental/wave/common/CudaUtil.cuh"
2727

2828
/// Utilities for booleans and indices and thread blocks.
@@ -137,8 +137,12 @@ template <
137137
int32_t kItemsPerThread,
138138
typename Key,
139139
typename Value>
140-
using RadixSort =
141-
typename cub::BlockRadixSort<Key, kBlockSize, kItemsPerThread, Value>;
140+
using RadixSort = typename breeze::functions::BlockRadixSort<
141+
CudaPlatform<kBlockSize, kWarpThreads>,
142+
kItemsPerThread,
143+
/*RADIX_BITS=*/8,
144+
Key,
145+
Value>;
142146

143147
template <
144148
int32_t kBlockSize,
@@ -147,7 +151,7 @@ template <
147151
typename Value>
148152
inline int32_t __host__ __device__ blockSortSharedSize() {
149153
return sizeof(
150-
typename RadixSort<kBlockSize, kItemsPerThread, Key, Value>::TempStorage);
154+
typename RadixSort<kBlockSize, kItemsPerThread, Key, Value>::Scratch);
151155
}
152156

153157
template <
@@ -165,7 +169,9 @@ void __device__ blockSort(
165169
char* smem) {
166170
using namespace breeze::functions;
167171
using namespace breeze::utils;
168-
using Sort = cub::BlockRadixSort<Key, kBlockSize, kItemsPerThread, Value>;
172+
173+
CudaPlatform<kBlockSize, kWarpThreads> p;
174+
using RadixSortT = RadixSort<kBlockSize, kItemsPerThread, Key, Value>;
169175

170176
// Per-thread tile items
171177
Key keys[kItemsPerThread];
@@ -174,28 +180,37 @@ void __device__ blockSort(
174180
// Our current block's offset
175181
int blockOffset = 0;
176182

177-
// Load items into a blocked arrangement
183+
constexpr int32_t kWarpItems = kWarpThreads * kItemsPerThread;
184+
static_assert(
185+
(kBlockSize % kWarpThreads) == 0,
186+
"kBlockSize must be a multiple of kWarpThreads");
187+
188+
// Load items into a warp-striped arrangement
189+
int32_t threadOffset = p.warp_idx() * kWarpItems + p.lane_idx();
178190
for (auto i = 0; i < kItemsPerThread; ++i) {
179-
int32_t idx = blockOffset + i * kBlockSize + threadIdx.x;
191+
int32_t idx = blockOffset + threadOffset + i * kWarpThreads;
180192
values[i] = valueGetter(idx);
181193
keys[i] = keyGetter(idx);
182194
}
183195

184196
__syncthreads();
185-
auto* temp_storage = reinterpret_cast<typename Sort::TempStorage*>(smem);
197+
auto* temp_storage = reinterpret_cast<typename RadixSortT::Scratch*>(smem);
186198

187-
Sort(*temp_storage).SortBlockedToStriped(keys, values);
199+
RadixSortT::Sort(
200+
p,
201+
make_slice<THREAD, WARP_STRIPED>(keys),
202+
make_slice<THREAD, WARP_STRIPED>(values),
203+
make_slice(temp_storage).template reinterpret<SHARED>());
188204

189-
// Store a striped arrangement of output across the thread block into a linear
190-
// segment of items
191-
CudaPlatform<kBlockSize, kWarpThreads> p;
205+
// Store a warp-striped arrangement of output across the thread block into a
206+
// linear segment of items
192207
BlockStore<kBlockSize, kItemsPerThread>(
193208
p,
194-
make_slice<THREAD, STRIPED>(values),
209+
make_slice<THREAD, WARP_STRIPED>(values),
195210
make_slice<GLOBAL>(valueOut + blockOffset));
196211
BlockStore<kBlockSize, kItemsPerThread>(
197212
p,
198-
make_slice<THREAD, STRIPED>(keys),
213+
make_slice<THREAD, WARP_STRIPED>(keys),
199214
make_slice<GLOBAL>(keyOut + blockOffset));
200215
__syncthreads();
201216
}

velox/experimental/wave/common/tests/BlockTest.cu

+4-8
Original file line numberDiff line numberDiff line change
@@ -187,10 +187,8 @@ void __global__ __launch_bounds__(1024)
187187
testSortNoShared(uint16_t** keys, uint16_t** values, char* smem) {
188188
auto keyBase = keys[blockIdx.x];
189189
auto valueBase = values[blockIdx.x];
190-
char* tbTemp = smem +
191-
blockIdx.x *
192-
sizeof(typename cub::BlockRadixSort<uint16_t, 256, 32, uint16_t>::
193-
TempStorage);
190+
char* tbTemp =
191+
smem + blockIdx.x * blockSortSharedSize<256, 32, uint16_t, uint16_t>();
194192

195193
blockSort<256, 32>(
196194
[&](auto i) { return keyBase[i]; },
@@ -202,16 +200,14 @@ void __global__ __launch_bounds__(1024)
202200
}
203201

204202
int32_t BlockTestStream::sort16SharedSize() {
205-
return sizeof(
206-
typename cub::BlockRadixSort<uint16_t, 256, 32, uint16_t>::TempStorage);
203+
return blockSortSharedSize<256, 32, uint16_t, uint16_t>();
207204
}
208205

209206
void BlockTestStream::testSort16(
210207
int32_t numBlocks,
211208
uint16_t** keys,
212209
uint16_t** values) {
213-
auto tempBytes = sizeof(
214-
typename cub::BlockRadixSort<uint16_t, 256, 32, uint16_t>::TempStorage);
210+
auto tempBytes = blockSortSharedSize<256, 32, uint16_t, uint16_t>();
215211

216212
testSort<<<numBlocks, 256, tempBytes, stream_->stream>>>(keys, values);
217213
}

0 commit comments

Comments
 (0)