Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

build: Remove usage of cub from Block.cuh and related test code #12545

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
22 changes: 22 additions & 0 deletions velox/experimental/breeze/breeze/functions/sort.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,28 @@ struct RadixSortTraits {
static ATTR T from_bit_ordered(T value);
};

// Specialization for signed 16-bit keys. A signed integer becomes
// bit-comparable (i.e. unsigned bitwise order matches signed numeric order)
// when its sign bit is flipped; XOR-ing with the same mask again restores the
// original value, so both conversions share one expression.
template <>
struct RadixSortTraits<short> {
  static ATTR short to_bit_ordered(short value) {
    const int sign_bit = 1 << utils::Msb<short>::VALUE;
    return static_cast<short>(value ^ sign_bit);
  }
  static ATTR short from_bit_ordered(short value) {
    const int sign_bit = 1 << utils::Msb<short>::VALUE;
    return static_cast<short>(value ^ sign_bit);
  }
};

// Specialization for unsigned 16-bit keys. Unsigned integers already sort
// correctly by their raw bit pattern, so both conversions are the identity.
template <>
struct RadixSortTraits<unsigned short> {
  static ATTR unsigned short to_bit_ordered(unsigned short value) {
    return value;
  }
  static ATTR unsigned short from_bit_ordered(unsigned short value) {
    return value;
  }
};

// specialization for T=int
template <>
struct RadixSortTraits<int> {
Expand Down
16 changes: 16 additions & 0 deletions velox/experimental/wave/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Use breeze PTX specialization by default for CUDA. Only seed the cache
# variable when the user has not already provided one, so a -D on the command
# line (or a preset) still wins.
if(NOT DEFINED CUDA_PLATFORM_SPECIALIZATION_HEADER)
  set(CUDA_PLATFORM_SPECIALIZATION_HEADER
      breeze/platforms/specialization/cuda-ptx.cuh
      CACHE STRING "CUDA platform specialization header")
endif()

# Add a header-only (INTERFACE) library for breeze and its CUDA platform.
# Consumers linking breeze_cuda inherit the include path and the
# PLATFORM_CUDA / specialization-header compile definitions.
add_library(breeze_cuda INTERFACE)
target_include_directories(breeze_cuda INTERFACE ../breeze)
target_compile_definitions(
  breeze_cuda
  INTERFACE
  PLATFORM_CUDA
  CUDA_PLATFORM_SPECIALIZATION_HEADER=${CUDA_PLATFORM_SPECIALIZATION_HEADER})

add_subdirectory(common)
add_subdirectory(exec)
add_subdirectory(vector)
Expand Down
177 changes: 105 additions & 72 deletions velox/experimental/wave/common/Block.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@

#pragma once

#include <cub/block/block_radix_sort.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_scan.cuh>
#include <cub/block/block_store.cuh>
#include <breeze/functions/reduce.h>
#include <breeze/functions/scan.h>
#include <breeze/functions/sort.h>
#include <breeze/functions/store.h>
#include <breeze/platforms/platform.h>
#include <breeze/utils/types.h>
#include <breeze/platforms/cuda.cuh>
#include "velox/experimental/wave/common/CudaUtil.cuh"

/// Utilities for booleans and indices and thread blocks.
Expand All @@ -29,47 +32,51 @@ namespace facebook::velox::wave {
/// Converts an array of flags to an array of indices of set flags. The first
/// index is given by 'start'. The number of indices is returned in 'size', i.e.
/// this is 1 + the index of the last set flag.
template <typename T, int32_t blockSize>
inline int32_t __device__ __host__ boolToIndicesSharedSize() {
  // Size in bytes of the breeze BlockScan scratch area that
  // boolBlockToIndices() expects in 'shmem'.
  using PlatformT = CudaPlatform<blockSize, kWarpThreads>;
  using BlockScanT =
      breeze::functions::BlockScan<PlatformT, T, /*kItemsPerThread=*/1>;
  return sizeof(typename BlockScanT::Scratch);
}

/// Converts an array of flags to an array of indices of set flags. The first
/// index is given by 'start'. The number of indices is returned in 'size', i.e.
/// this is 1 + the index of the last set flag.
template <int32_t blockSize, typename T, typename Getter>
__device__ inline void
boolBlockToIndices(Getter getter, T start, T* indices, void* shmem, T& size) {
  using namespace breeze::functions;
  using namespace breeze::utils;

  CudaPlatform<blockSize, kWarpThreads> p;
  using BlockScanT = BlockScan<decltype(p), T, /*kItemsPerThread=*/1>;

  // 'shmem' must hold at least boolToIndicesSharedSize<T, blockSize>() bytes.
  auto* scratch = reinterpret_cast<typename BlockScanT::Scratch*>(shmem);
  T items[1];
  uint8_t flag = getter();
  items[0] = flag;
  __syncthreads();
  // Inclusive prefix sum of the flags; the returned aggregate is the total
  // number of set flags in the block.
  T aggregate = BlockScanT::template Scan<ScanOpAdd>(
      p,
      make_slice(items),
      make_slice(items),
      make_slice(scratch).template reinterpret<SHARED>());
  if (flag) {
    // Subtracting this thread's own flag turns the inclusive result into the
    // exclusive rank, i.e. the output slot for this thread's index.
    indices[items[0] - flag] = threadIdx.x + start;
  }
  if (threadIdx.x == (blockSize - 1)) {
    size = aggregate;
  }
  __syncthreads();
}

/// Bytes of shared memory needed by bool256ToIndices(): 32 per-warp-lane
/// counters plus one extra uint16_t slot.
inline int32_t __device__ __host__ bool256ToIndicesSize() {
  constexpr int32_t kSlots = 33;
  return kSlots * sizeof(uint16_t);
}

/// Returns indices of set bits for 256 one byte flags. 'getter8' is
Expand All @@ -80,7 +87,7 @@ inline int32_t __device__ __host__ bool256ToIndicesSize() {
template <typename T, typename Getter8>
__device__ inline void
bool256ToIndices(Getter8 getter8, T start, T* indices, T& size, char* smem) {
using Scan = cub::WarpScan<uint16_t>;
CudaPlatform<kWarpThreads, kWarpThreads> p;
auto* smem16 = reinterpret_cast<uint16_t*>(smem);
int32_t group = threadIdx.x / 8;
uint64_t bits = getter8(group) & 0x0101010101010101;
Expand All @@ -89,10 +96,8 @@ bool256ToIndices(Getter8 getter8, T start, T* indices, T& size, char* smem) {
}
__syncthreads();
if (threadIdx.x < 32) {
auto* temp = reinterpret_cast<typename Scan::TempStorage*>((smem + 72));
uint16_t data = smem16[threadIdx.x];
uint16_t result;
Scan(*temp).ExclusiveSum(data, result);
uint16_t result = p.scan_add(data) - data;
smem16[threadIdx.x] = result;
if (threadIdx.x == 31) {
size = data + result;
Expand All @@ -110,15 +115,20 @@ bool256ToIndices(Getter8 getter8, T start, T* indices, T& size, char* smem) {

/// Sums getter() across all threads of the block and writes the per-block
/// total to result[blockIdx.x]. 'shmem' must point to shared memory large
/// enough for the breeze BlockReduce scratch area.
template <int32_t blockSize, typename T, typename Getter>
__device__ inline void blockSum(Getter getter, void* shmem, T* result) {
  using namespace breeze::functions;
  using namespace breeze::utils;

  CudaPlatform<blockSize, kWarpThreads> p;
  using BlockReduceT = BlockReduce<decltype(p), T>;

  auto* scratch = reinterpret_cast<typename BlockReduceT::Scratch*>(shmem);
  T items[1] = {getter()};
  T total = BlockReduceT::template Reduce<ReduceOpAdd, /*kItemsPerThread=*/1>(
      p, make_slice(items), make_slice(scratch).template reinterpret<SHARED>());
  // A single thread publishes the block aggregate.
  if (p.thread_idx() == 0) {
    result[p.block_idx()] = total;
  }
}

Expand All @@ -127,8 +137,12 @@ template <
int32_t kItemsPerThread,
typename Key,
typename Value>
using RadixSort =
typename cub::BlockRadixSort<Key, kBlockSize, kItemsPerThread, Value>;
using RadixSort = typename breeze::functions::BlockRadixSort<
CudaPlatform<kBlockSize, kWarpThreads>,
kItemsPerThread,
/*RADIX_BITS=*/8,
Key,
Value>;

template <
int32_t kBlockSize,
Expand All @@ -137,7 +151,7 @@ template <
typename Value>
inline int32_t __host__ __device__ blockSortSharedSize() {
return sizeof(
typename RadixSort<kBlockSize, kItemsPerThread, Key, Value>::TempStorage);
typename RadixSort<kBlockSize, kItemsPerThread, Key, Value>::Scratch);
}

template <
Expand All @@ -153,7 +167,11 @@ void __device__ blockSort(
Key* keyOut,
Value* valueOut,
char* smem) {
using Sort = cub::BlockRadixSort<Key, kBlockSize, kItemsPerThread, Value>;
using namespace breeze::functions;
using namespace breeze::utils;

CudaPlatform<kBlockSize, kWarpThreads> p;
using RadixSortT = RadixSort<kBlockSize, kItemsPerThread, Key, Value>;

// Per-thread tile items
Key keys[kItemsPerThread];
Expand All @@ -162,35 +180,53 @@ void __device__ blockSort(
// Our current block's offset
int blockOffset = 0;

// Load items into a blocked arrangement
constexpr int32_t kWarpItems = kWarpThreads * kItemsPerThread;
static_assert(
(kBlockSize % kWarpThreads) == 0,
"kBlockSize must be a multiple of kWarpThreads");

// Load items into a warp-striped arrangement
int32_t threadOffset = p.warp_idx() * kWarpItems + p.lane_idx();
for (auto i = 0; i < kItemsPerThread; ++i) {
int32_t idx = blockOffset + i * kBlockSize + threadIdx.x;
int32_t idx = blockOffset + threadOffset + i * kWarpThreads;
values[i] = valueGetter(idx);
keys[i] = keyGetter(idx);
}

__syncthreads();
auto* temp_storage = reinterpret_cast<typename Sort::TempStorage*>(smem);
auto* temp_storage = reinterpret_cast<typename RadixSortT::Scratch*>(smem);

Sort(*temp_storage).SortBlockedToStriped(keys, values);
RadixSortT::Sort(
p,
make_slice<THREAD, WARP_STRIPED>(keys),
make_slice<THREAD, WARP_STRIPED>(values),
make_slice(temp_storage).template reinterpret<SHARED>());

// Store output in striped fashion
cub::StoreDirectStriped<kBlockSize>(
threadIdx.x, valueOut + blockOffset, values);
cub::StoreDirectStriped<kBlockSize>(threadIdx.x, keyOut + blockOffset, keys);
// Store a warp-striped arrangement of output across the thread block into a
// linear segment of items
BlockStore<kBlockSize, kItemsPerThread>(
p,
make_slice<THREAD, WARP_STRIPED>(values),
make_slice<GLOBAL>(valueOut + blockOffset));
BlockStore<kBlockSize, kItemsPerThread>(
p,
make_slice<THREAD, WARP_STRIPED>(keys),
make_slice<GLOBAL>(keyOut + blockOffset));
__syncthreads();
}

/// Shared-memory bytes needed by partitionRows() for 'numPartitions' output
/// partitions: a scan scratch region (plus one int32_t used to carry the
/// running aggregate between scan rounds), with the per-partition counters
/// either overlaid on that region or appended after it when they do not fit.
template <int kBlockSize>
int32_t partitionRowsSharedSize(int32_t numPartitions) {
  using namespace breeze::functions;
  using PlatformT = CudaPlatform<kBlockSize, kWarpThreads>;
  using BlockScanT = BlockScan<PlatformT, int32_t, /*kItemsPerThread=*/1>;
  auto scanSize =
      max(sizeof(typename BlockScanT::Scratch), sizeof(int32_t) * kBlockSize) +
      sizeof(int32_t);
  int32_t counterSize = sizeof(int32_t) * numPartitions;
  // Counters small enough to share the scan area need no extra space.
  if (counterSize <= scanSize) {
    return scanSize;
  }
  return scanSize + counterSize;
}

Expand All @@ -211,16 +247,19 @@ void __device__ partitionRows(
RowNumber* ranks,
RowNumber* partitionStarts,
RowNumber* partitionedRows) {
using Scan = cub::BlockScan<int32_t, kBlockSize>;
constexpr int32_t kWarpThreads = 1 << CUB_LOG_WARP_THREADS(0);
using namespace breeze::functions;
using namespace breeze::utils;

CudaPlatform<kBlockSize, kWarpThreads> p;
using BlockScanT = BlockScan<decltype(p), int32_t, /*kItemsPerThread=*/1>;
auto warp = threadIdx.x / kWarpThreads;
auto lane = cub::LaneId();
auto lane = threadIdx.x % kWarpThreads;
extern __shared__ __align__(16) char smem[];
auto* counters = reinterpret_cast<uint32_t*>(
numPartitions <= kBlockSize ? smem
: smem +
sizeof(typename Scan::
TempStorage) /*- kBlockSize * sizeof(uint32_t)*/);
sizeof(typename BlockScanT::
Scratch) /*- kBlockSize * sizeof(uint32_t)*/);
for (auto i = threadIdx.x; i < numPartitions; i += kBlockSize) {
counters[i] = 0;
}
Expand Down Expand Up @@ -248,7 +287,7 @@ void __device__ partitionRows(
}
// Prefix sum the counts. All counters must have their final value.
__syncthreads();
auto* temp = reinterpret_cast<typename Scan::TempStorage*>(smem);
auto* temp = reinterpret_cast<typename BlockScanT::Scratch*>(smem);
int32_t* aggregate = reinterpret_cast<int32_t*>(smem);
for (auto start = 0; start < numPartitions; start += kBlockSize) {
int32_t localCount[1];
Expand All @@ -258,7 +297,11 @@ void __device__ partitionRows(
// The sum of the previous round is carried over as start of this.
localCount[0] += *aggregate;
}
Scan(*temp).InclusiveSum(localCount, localCount);
BlockScanT::template Scan<ScanOpAdd>(
p,
make_slice(localCount),
make_slice(localCount),
make_slice(temp).template reinterpret<SHARED>());
if (start + threadIdx.x < numPartitions) {
partitionStarts[start + threadIdx.x] = localCount[0];
}
Expand Down Expand Up @@ -289,10 +332,8 @@ void __device__ partitionRows(
template <typename T, int32_t kBlockSize>
inline __device__ T exclusiveSum(T input, T* total, T* temp) {
constexpr int32_t kNumWarps = kBlockSize / kWarpThreads;
using Scan = cub::WarpScan<T>;
T sum;
Scan(*reinterpret_cast<typename Scan::TempStorage*>(temp))
.ExclusiveSum(input, sum);
CudaPlatform<kBlockSize, kWarpThreads> p;
T sum = p.scan_add(input) - input;
if (kBlockSize == kWarpThreads) {
if (total) {
if (threadIdx.x == kWarpThreads - 1) {
Expand All @@ -306,11 +347,8 @@ inline __device__ T exclusiveSum(T input, T* total, T* temp) {
temp[threadIdx.x / kWarpThreads] = input + sum;
}
__syncthreads();
using InnerScan = cub::WarpScan<T, kNumWarps>;
T warpSum = threadIdx.x < kNumWarps ? temp[threadIdx.x] : 0;
T blockSum;
InnerScan(*reinterpret_cast<typename InnerScan::TempStorage*>(temp))
.ExclusiveSum(warpSum, blockSum);
T blockSum = p.scan_add(warpSum) - warpSum;
if (threadIdx.x < kNumWarps) {
temp[threadIdx.x] = blockSum;
if (total && threadIdx.x == kNumWarps - 1) {
Expand All @@ -328,10 +366,8 @@ inline __device__ T exclusiveSum(T input, T* total, T* temp) {
template <typename T, int32_t kBlockSize>
inline __device__ T inclusiveSum(T input, T* total, T* temp) {
constexpr int32_t kNumWarps = kBlockSize / kWarpThreads;
using Scan = cub::WarpScan<T>;
T sum;
Scan(*reinterpret_cast<typename Scan::TempStorage*>(temp))
.InclusiveSum(input, sum);
CudaPlatform<kBlockSize, kWarpThreads> p;
T sum = p.scan_add(input);
if (kBlockSize <= kWarpThreads) {
if (total != nullptr) {
if (threadIdx.x == kBlockSize - 1) {
Expand All @@ -346,11 +382,8 @@ inline __device__ T inclusiveSum(T input, T* total, T* temp) {
}
__syncthreads();
constexpr int32_t kInnerWidth = kNumWarps < 2 ? 2 : kNumWarps;
using InnerScan = cub::WarpScan<T, kInnerWidth>;
T warpSum = threadIdx.x < kInnerWidth ? temp[threadIdx.x] : 0;
T blockSum;
InnerScan(*reinterpret_cast<typename InnerScan::TempStorage*>(temp))
.ExclusiveSum(warpSum, blockSum);
T blockSum = p.scan_add(warpSum) - warpSum;
if (threadIdx.x < kInnerWidth) {
temp[threadIdx.x] = blockSum;
}
Expand Down
Loading
Loading