diff --git a/Src/Base/AMReX_Scan.H b/Src/Base/AMReX_Scan.H
index 11fdfd8bd70..f9fff3b753c 100644
--- a/Src/Base/AMReX_Scan.H
+++ b/Src/Base/AMReX_Scan.H
@@ -400,7 +400,7 @@ T PrefixSum (N n, FIN && fin, FOUT && fout, TYPE type, RetSum a_ret_sum = retSum
 
 #ifndef AMREX_SYCL_NO_MULTIPASS_SCAN
     if (nblocks > 1) {
-        return PrefixSum_mp<T>(n, std::forward<FIN>(fin), std::forward<FOUT>(fout), type, retSum);
+        return PrefixSum_mp<T>(n, std::forward<FIN>(fin), std::forward<FOUT>(fout), type, a_ret_sum);
     }
 #endif
 
@@ -621,7 +621,179 @@ T PrefixSum (N n, FIN && fin, FOUT && fout, TYPE type, RetSum a_ret_sum = retSum
     return totalsum;
 }
 
-#elif defined(AMREX_USE_HIP)
+#else // #if defined(AMREX_USE_SYCL)
+
+#define AMREX_GPU_MULTIPASS_SCAN 1
+
+#if defined(AMREX_GPU_MULTIPASS_SCAN)
+template <int depth, typename T, typename N, typename FIN, typename FOUT, typename TYPE>
+T PrefixSum_mp (N n, FIN && fin, FOUT && fout, TYPE, RetSum a_ret_sum)
+{
+    if (n <= 0) { return 0; }
+#if defined(AMREX_USE_HIP)
+    constexpr int nwarps_per_block = 4;
+#else
+    constexpr int nwarps_per_block = 8;
+#endif
+    constexpr int nthreads = nwarps_per_block*Gpu::Device::warp_size;
+    constexpr int nelms_per_thread = 12;
+    constexpr int nelms_per_block = nthreads * nelms_per_thread;
+    AMREX_ALWAYS_ASSERT(static_cast<Long>(n) < static_cast<Long>(std::numeric_limits<int>::max())*nelms_per_block);
+    int nblocks = (static_cast<Long>(n) + nelms_per_block - 1) / nelms_per_block;
+    std::size_t sm = 0;
+    auto stream = Gpu::gpuStream();
+
+    std::size_t nbytes_blockresult = Arena::align(sizeof(T)*n);
+    std::size_t nbytes_blocksum = Arena::align(sizeof(T)*nblocks);
+    std::size_t nbytes_totalsum = Arena::align(sizeof(T));
+    auto dp = (char*)(The_Arena()->alloc(nbytes_blockresult
+                                         + nbytes_blocksum
+                                         + nbytes_totalsum));
+    T* blockresult_p = (T*)dp;
+    T* blocksum_p = (T*)(dp + nbytes_blockresult);
+    T* totalsum_p = (T*)(dp + nbytes_blockresult + nbytes_blocksum);
+
+    amrex::launch(nblocks, nthreads, sm, stream,
+    [=] AMREX_GPU_DEVICE () noexcept
+    {
+        // Each block processes [ibegin,iend).
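+        // Pass 1: each block loads its elements, performs a block-wide scan,
+        // and stores both the per-element partial results and the block total.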
+        N ibegin = nelms_per_block * blockIdx.x;
+        N iend = amrex::min(static_cast<N>(ibegin+nelms_per_block), n);
+
+        T block_agg;
+        T data[nelms_per_thread];
+
+        constexpr bool is_exclusive = std::is_same<std::decay_t<TYPE>,Type::Exclusive>::value;
+
+#if defined(AMREX_USE_CUDA)
+
+        using BlockLoad = cub::BlockLoad<T, nthreads, nelms_per_thread>;
+        using BlockScan = cub::BlockScan<T, nthreads>;
+        using BlockExchange = cub::BlockExchange<T, nthreads, nelms_per_thread>;
+
+        __shared__ union TempStorage
+        {
+            typename BlockLoad::TempStorage load;
+            typename BlockExchange::TempStorage exchange;
+            typename BlockScan::TempStorage scan;
+        } temp_storage;
+
+        auto input_lambda = [&] (N i) -> T { return fin(i+ibegin); };
+        cub::TransformInputIterator<T, decltype(input_lambda), cub::CountingInputIterator<N> >
+            input_begin(cub::CountingInputIterator<N>(0), input_lambda);
+
+        if (static_cast<int>(iend-ibegin) == nelms_per_block) {
+            BlockLoad(temp_storage.load).Load(input_begin, data);
+        } else {
+            BlockLoad(temp_storage.load).Load(input_begin, data, iend-ibegin, 0); // padding with 0
+        }
+
+        __syncthreads();
+
+        AMREX_IF_CONSTEXPR(is_exclusive) {
+            BlockScan(temp_storage.scan).ExclusiveSum(data, data, block_agg);
+        } else {
+            BlockScan(temp_storage.scan).InclusiveSum(data, data, block_agg);
+        }
+
+        __syncthreads();
+
+        BlockExchange(temp_storage.exchange).BlockedToStriped(data);
+
+#else
+
+        using BlockLoad = rocprim::block_load<T, nthreads, nelms_per_thread>;
+        using BlockScan = rocprim::block_scan<T, nthreads>;
+        using BlockExchange = rocprim::block_exchange<T, nthreads, nelms_per_thread>;
+
+        __shared__ union TempStorage {
+            typename BlockLoad::storage_type load;
+            typename BlockExchange::storage_type exchange;
+            typename BlockScan::storage_type scan;
+        } temp_storage;
+
+        auto input_begin = rocprim::make_transform_iterator(
+            rocprim::make_counting_iterator(N(0)),
+            [&] (N i) -> T { return fin(i+ibegin); });
+
+        if (static_cast<int>(iend-ibegin) == nelms_per_block) {
+            BlockLoad().load(input_begin, data, temp_storage.load);
+        } else {
+            BlockLoad().load(input_begin, data, iend-ibegin, 0, temp_storage.load); // padding with 0
+        }
+
+        __syncthreads();
+
+        AMREX_IF_CONSTEXPR(is_exclusive) {
+            BlockScan().exclusive_scan(data, data, T{0}, block_agg, temp_storage.scan);
+        } else {
+            BlockScan().inclusive_scan(data, data, block_agg, temp_storage.scan);
+        }
+
+        __syncthreads();
+
+        BlockExchange().blocked_to_striped(data, data, temp_storage.exchange);
+
+#endif
+
+        for (int i = 0; i < nelms_per_thread; ++i) {
+            N offset = ibegin + i*blockDim.x + threadIdx.x;
+            if (offset < iend) {
+                blockresult_p[offset] = data[i];
+            }
+        }
+
+        if (threadIdx.x == 0) {
+            if (nblocks == 1) {
+                *totalsum_p = block_agg;
+            }
+            blocksum_p[blockIdx.x] = block_agg;
+        }
+    });
+
+    T totalsum = 0;
+    if (nblocks > 1) {
+        if constexpr (depth < 2) {
+            totalsum = PrefixSum_mp<depth+1,T>(nblocks,
+                           [=] AMREX_GPU_DEVICE (int i)
+                               { return blocksum_p[i]; },
+                           [=] AMREX_GPU_DEVICE (int i, T const& s)
+                               { blocksum_p[i] = s; },
+                           Type::exclusive, a_ret_sum);
+        } else {
+            amrex::Abort("PrefixSum_mp: recursion is too deep");
+            The_Arena()->free(dp);
+            return totalsum;
+        }
+    }
+
+    amrex::launch(nblocks, nthreads, 0, stream,
+    [=] AMREX_GPU_DEVICE () noexcept
+    {
+        // Each block processes [ibegin,iend).
+        N ibegin = nelms_per_block * blockIdx.x;
+        N iend = amrex::min(static_cast<N>(ibegin+nelms_per_block), n);
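+        // prev_sum is the sum over all preceding blocks: zero for the first block,
+        // otherwise the exclusively scanned block total computed above.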
+        T prev_sum = (blockIdx.x == 0) ?
+            0 : blocksum_p[blockIdx.x];
+        for (N offset = ibegin + threadIdx.x; offset < iend; offset += blockDim.x) {
+            fout(offset, prev_sum + blockresult_p[offset]);
+        }
+    });
+
+    if (a_ret_sum && nblocks == 1) {
+        Gpu::dtoh_memcpy_async(&totalsum, totalsum_p, sizeof(T));
+    }
+    Gpu::streamSynchronize();
+    The_Arena()->free(dp);
+
+    AMREX_GPU_ERROR_CHECK();
+
+    return totalsum;
+}
+#endif // #if defined(AMREX_GPU_MULTIPASS_SCAN)
+
+#if defined(AMREX_USE_HIP)
 
 template <typename T, typename N, typename FIN, typename FOUT, typename TYPE,
           typename M=std::enable_if_t<std::is_integral<N>::value &&
@@ -634,7 +806,15 @@ T PrefixSum (N n, FIN && fin, FOUT && fout, TYPE, RetSum a_ret_sum = retSum)
     constexpr int nthreads = nwarps_per_block*Gpu::Device::warp_size; // # of threads per block
     constexpr int nelms_per_thread = sizeof(T) >= 8 ? 8 : 16;
     constexpr int nelms_per_block = nthreads * nelms_per_thread;
-    int nblocks = (n + nelms_per_block - 1) / nelms_per_block;
+    AMREX_ALWAYS_ASSERT(static_cast<Long>(n) < static_cast<Long>(std::numeric_limits<int>::max())*nelms_per_block);
+    int nblocks = (Long(n) + nelms_per_block - 1) / nelms_per_block;
+
+#if defined(AMREX_GPU_MULTIPASS_SCAN)
+    if (nblocks > 1) {
+        return PrefixSum_mp<0,T>(n, std::forward<FIN>(fin), std::forward<FOUT>(fout), TYPE{}, a_ret_sum);
+    }
+#endif
+
     std::size_t sm = 0;
     auto stream = Gpu::gpuStream();
 
@@ -791,6 +971,12 @@ T PrefixSum (N n, FIN && fin, FOUT && fout, TYPE, RetSum a_ret_sum = retSum)
     ScanTileState tile_state;
     tile_state.Init(nblocks, tile_state_p, tile_state_size); // Init ScanTileState on host
 
+#if defined(AMREX_GPU_MULTIPASS_SCAN)
+    if (nblocks > 1) {
+        return PrefixSum_mp<0,T>(n, std::forward<FIN>(fin), std::forward<FOUT>(fout), TYPE{}, a_ret_sum);
+    }
+#endif
+
     if (nblocks > 1) { // Init ScanTileState on device
         amrex::launch((nblocks+nthreads-1)/nthreads, nthreads, 0, stream,
         [=] AMREX_GPU_DEVICE ()
@@ -912,6 +1098,13 @@ T PrefixSum (N n, FIN && fin, FOUT && fout, TYPE, RetSum a_ret_sum = retSum)
     constexpr int nelms_per_block = nthreads * nchunks;
     AMREX_ALWAYS_ASSERT(static_cast<Long>(n) < static_cast<Long>(std::numeric_limits<int>::max())*nelms_per_block);
     int nblocks = (static_cast<Long>(n) + nelms_per_block - 1) / nelms_per_block;
+
+#if defined(AMREX_GPU_MULTIPASS_SCAN)
+    if (nblocks > 1) {
+        return PrefixSum_mp<0,T>(n, std::forward<FIN>(fin), std::forward<FOUT>(fout), TYPE{}, a_ret_sum);
+    }
+#endif
+
     std::size_t sm = sizeof(T) * (Gpu::Device::warp_size + nwarps_per_block) + sizeof(int);
     auto stream = Gpu::gpuStream();
 
@@ -1141,7 +1334,9 @@ T PrefixSum (N n, FIN && fin, FOUT && fout, TYPE, RetSum a_ret_sum = retSum)
     return totalsum;
 }
 
-#endif
+#endif // #if defined(AMREX_USE_HIP)
+
+#endif // #if defined(AMREX_USE_SYCL)
 
 // The return value is the total sum if a_ret_sum is true.
 template <typename N, typename T, typename M=std::enable_if_t<std::is_integral<N>::value> >