Add a way to reuse TagParallelFor Tags #4348

Open · wants to merge 1 commit into base: development
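The change, in brief: `detail::ParallelFor_doit` used to rebuild and upload its tag/warp metadata on every launch; this PR factors that state into a new `TagVector` class so callers can pay the packing cost once and launch repeatedly. A minimal usage sketch, assuming a hypothetical box-based tag type `MyCopyTag` with `Array4` members `dst`/`src` and a hypothetical `make_copy_tags()` helper (the real tag types, such as `Array4CopyTag` in `AMReX_FBI.H`, differ in detail):

```cpp
// Build the tags once, pack them into a TagVector, and reuse the packed,
// device-resident metadata across many launches.
amrex::Vector<MyCopyTag> tags = make_copy_tags();  // hypothetical helper

amrex::TagVector<MyCopyTag> tv{tags};  // one host pack + one async upload

for (int step = 0; step < nsteps; ++step) {
    amrex::ParallelFor(tv, ncomp,
        [=] AMREX_GPU_DEVICE (int i, int j, int k, int n,
                              MyCopyTag const& tag) noexcept
        {
            tag.dst(i,j,k,n) = tag.src(i,j,k,n);  // per-cell work over tag.box()
        });
}
```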
8 changes: 6 additions & 2 deletions Src/Base/AMReX_FBI.H
@@ -53,7 +53,9 @@ void
fab_to_fab (Vector<Array4CopyTag<T0, T1> > const& copy_tags, int scomp, int dcomp, int ncomp,
F && f)
{
detail::ParallelFor_doit(copy_tags,
TagVector<Array4CopyTag<T0, T1>> tv{copy_tags};

detail::ParallelFor_doit(tv,
[=] AMREX_GPU_DEVICE (
#ifdef AMREX_USE_SYCL
sycl::nd_item<1> const& /*item*/,
@@ -85,7 +87,9 @@ fab_to_fab (Vector<Array4CopyTag<T0, T1> > const& copy_tags, int scomp, int dcom

amrex::Abort("xxxxx TODO This function still has a bug. Even if we fix the bug, it should still be avoided because it is slow due to the lack of atomic operations for this type.");

detail::ParallelFor_doit(tags,
TagVector<TagType> tv{tags};

detail::ParallelFor_doit(tv,
[=] AMREX_GPU_DEVICE (
#ifdef AMREX_USE_SYCL
sycl::nd_item<1> const& item,
297 changes: 226 additions & 71 deletions Src/Base/AMReX_TagParallelFor.H
@@ -101,31 +101,153 @@ struct VectorTag {
Long size () const noexcept { return m_size; }
};

#ifdef AMREX_USE_GPU

namespace detail {

template <typename T>
std::enable_if_t<std::is_same<std::decay_t<decltype(std::declval<T>().box())>, Box>::value,
Long>
get_tag_size (T const& tag) noexcept
{
AMREX_ASSERT(tag.box().numPts() < Long(std::numeric_limits<int>::max()));
return static_cast<int>(tag.box().numPts());
}
template <typename T>
std::enable_if_t<std::is_same_v<std::decay_t<decltype(std::declval<T>().box())>, Box>, Long>
get_tag_size (T const& tag) noexcept
{
AMREX_ASSERT(tag.box().numPts() < Long(std::numeric_limits<int>::max()));
return static_cast<int>(tag.box().numPts());
}

template <typename T>
std::enable_if_t<std::is_integral_v<std::decay_t<decltype(std::declval<T>().size())> >, Long>
get_tag_size (T const& tag) noexcept
{
AMREX_ASSERT(tag.size() < Long(std::numeric_limits<int>::max()));
return tag.size();
}

template <typename T>
constexpr
std::enable_if_t<std::is_same_v<std::decay_t<decltype(std::declval<T>().box())>, Box>, bool>
is_box_tag (T const&) { return true; }

template <typename T>
constexpr
std::enable_if_t<std::is_integral_v<std::decay_t<decltype(std::declval<T>().size())> >, bool>
is_box_tag (T const&) { return false; }

template <typename T>
std::enable_if_t<std::is_integral<std::decay_t<decltype(std::declval<T>().size())> >::value,
Long>
get_tag_size (T const& tag) noexcept
{
AMREX_ASSERT(tag.size() < Long(std::numeric_limits<int>::max()));
return tag.size();
}
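The `get_tag_size` overloads (and the new `is_box_tag` pair) dispatch on whether a tag type exposes a `box()` returning `Box` or an integral `size()`. A standalone sketch of the same SFINAE pattern, with stand-in types (`Box3`, `BoxTag`, `VecTag` are hypothetical, mirroring `amrex::Box` and the two tag families):

```cpp
#include <cstdint>
#include <type_traits>

struct Box3 { std::int64_t numPts () const { return 8; } };   // stand-in for amrex::Box

struct BoxTag { Box3 b;         Box3 box () const { return b; } };          // sized by its box
struct VecTag { std::int64_t n; std::int64_t size () const { return n; } }; // sized directly

// Chosen when T::box() exists and returns Box3: size is the box's point count.
template <typename T>
std::enable_if_t<std::is_same_v<std::decay_t<decltype(std::declval<T>().box())>, Box3>,
                 std::int64_t>
get_tag_size (T const& tag) { return tag.box().numPts(); }

// Chosen when T::size() exists and is integral: size is used directly.
template <typename T>
std::enable_if_t<std::is_integral_v<std::decay_t<decltype(std::declval<T>().size())>>,
                 std::int64_t>
get_tag_size (T const& tag) { return tag.size(); }

int main () {
    BoxTag bt{};
    VecTag vt{5};
    // Exactly one overload survives substitution for each tag type.
    return (get_tag_size(bt) == 8 && get_tag_size(vt) == 5) ? 0 : 1;
}
```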

template <class TagType>
struct TagVector {

char* h_buffer = nullptr;
char* d_buffer = nullptr;
TagType* d_tags = nullptr;
int* d_nwarps = nullptr;
int ntags = 0;
int ntotwarps = 0;
int nblocks = 0;
bool defined = false;
static constexpr int nthreads = 256;

TagVector () = default;

TagVector (Vector<TagType> const& tags) {
define(tags);
}

~TagVector () {
if (defined) {
undefine();
}
}

TagVector (const TagVector& other) = delete;
TagVector (TagVector&& other) = default;
TagVector& operator= (const TagVector& other) = delete;
TagVector& operator= (TagVector&& other) = default;

[[nodiscard]] bool is_defined () const { return defined; }

void define (Vector<TagType> const& tags) {
if (defined) {
undefine();
}

ntags = tags.size();
if (ntags == 0) {
defined = true;
return;
}

#ifdef AMREX_USE_GPU
Long l_ntotwarps = 0;
ntotwarps = 0;
Vector<int> nwarps;
nwarps.reserve(ntags+1);
for (int i = 0; i < ntags; ++i)
{
auto& tag = tags[i];
nwarps.push_back(ntotwarps);
auto nw = (detail::get_tag_size(tag) + Gpu::Device::warp_size-1) /
Gpu::Device::warp_size;
l_ntotwarps += nw;
ntotwarps += static_cast<int>(nw);
}
nwarps.push_back(ntotwarps);

std::size_t sizeof_tags = ntags*sizeof(TagType);
std::size_t offset_nwarps = Arena::align(sizeof_tags);
std::size_t sizeof_nwarps = (ntags+1)*sizeof(int);
std::size_t total_buf_size = offset_nwarps + sizeof_nwarps;

h_buffer = (char*)The_Pinned_Arena()->alloc(total_buf_size);
d_buffer = (char*)The_Arena()->alloc(total_buf_size);

std::memcpy(h_buffer, tags.data(), sizeof_tags);
std::memcpy(h_buffer+offset_nwarps, nwarps.data(), sizeof_nwarps);
Gpu::htod_memcpy_async(d_buffer, h_buffer, total_buf_size);

d_tags = reinterpret_cast<TagType*>(d_buffer);
d_nwarps = reinterpret_cast<int*>(d_buffer+offset_nwarps);

constexpr int nwarps_per_block = nthreads/Gpu::Device::warp_size;
nblocks = (ntotwarps + nwarps_per_block-1) / nwarps_per_block;

defined = true;

amrex::ignore_unused(l_ntotwarps);
AMREX_ALWAYS_ASSERT(l_ntotwarps+nwarps_per_block-1 < Long(std::numeric_limits<int>::max()));
#else
std::size_t sizeof_tags = ntags*sizeof(TagType);
h_buffer = (char*)The_Pinned_Arena()->alloc(sizeof_tags);

std::memcpy(h_buffer, tags.data(), sizeof_tags);

d_tags = reinterpret_cast<TagType*>(h_buffer);

defined = true;
#endif
}

void undefine () {
if (defined) {
Gpu::streamSynchronize();
The_Pinned_Arena()->free(h_buffer);
The_Arena()->free(d_buffer);
h_buffer = nullptr;
d_buffer = nullptr;
d_tags = nullptr;
d_nwarps = nullptr;
ntags = 0;
ntotwarps = 0;
nblocks = 0;
defined = false;
}
}
};
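On the GPU build, `define()` packs the tag array and an `ntags+1` warp-offset prefix array into one pinned-host plus one device allocation, aligning the second array with `Arena::align`, and uploads both with a single async copy, so a reused `TagVector` pays this cost only once. A standalone sketch of that layout arithmetic (the 64-byte alignment and `sizeof_tag` are assumptions for illustration):

```cpp
#include <cstddef>
#include <cstdio>

// Stand-in for Arena::align; the real alignment is an AMReX implementation detail.
constexpr std::size_t align_up (std::size_t n, std::size_t a = 64) {
    return (n + a - 1) / a * a;
}

int main () {
    constexpr int ntags = 1000;
    constexpr std::size_t sizeof_tag = 96;  // hypothetical sizeof(TagType)

    std::size_t sizeof_tags   = ntags * sizeof_tag;         // tag array
    std::size_t offset_nwarps = align_up(sizeof_tags);      // aligned start of nwarps
    std::size_t sizeof_nwarps = (ntags + 1) * sizeof(int);  // prefix sum of warps per tag
    std::size_t total         = offset_nwarps + sizeof_nwarps;

    // Both arrays are memcpy'd into one pinned host buffer of 'total' bytes
    // and uploaded to the matching device buffer with a single async copy.
    std::printf("tags: [0, %zu)  nwarps: [%zu, %zu)\n",
                sizeof_tags, offset_nwarps, total);
}
```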

namespace detail {

#ifdef AMREX_USE_GPU

template <typename T, typename F>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
std::enable_if_t<std::is_same<std::decay_t<decltype(std::declval<T>().box())>, Box>::value, void>
std::enable_if_t<std::is_same_v<std::decay_t<decltype(std::declval<T>().box())>, Box>, void>
tagparfor_call_f (
#ifdef AMREX_USE_SYCL
sycl::nd_item<1> const& item,
@@ -150,7 +272,7 @@

template <typename T, typename F>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
std::enable_if_t<std::is_integral<std::decay_t<decltype(std::declval<T>().size())> >::value, void>
std::enable_if_t<std::is_integral_v<std::decay_t<decltype(std::declval<T>().size())> >, void>
tagparfor_call_f (
#ifdef AMREX_USE_SYCL
sycl::nd_item<1> const& item,
@@ -167,48 +289,19 @@

template <class TagType, class F>
void
ParallelFor_doit (Vector<TagType> const& tags, F && f)
ParallelFor_doit (TagVector<TagType> const& tv, F const& f)
{
const int ntags = tags.size();
if (ntags == 0) { return; }

Long l_ntotwarps = 0;
int ntotwarps = 0;
Vector<int> nwarps;
nwarps.reserve(ntags+1);
for (int i = 0; i < ntags; ++i)
{
auto& tag = tags[i];
nwarps.push_back(ntotwarps);
auto nw = (get_tag_size(tag) + Gpu::Device::warp_size-1) / Gpu::Device::warp_size;
l_ntotwarps += nw;
ntotwarps += static_cast<int>(nw);
}
nwarps.push_back(ntotwarps);
AMREX_ALWAYS_ASSERT(tv.is_defined());

std::size_t sizeof_tags = ntags*sizeof(TagType);
std::size_t offset_nwarps = Arena::align(sizeof_tags);
std::size_t sizeof_nwarps = (ntags+1)*sizeof(int);
std::size_t total_buf_size = offset_nwarps + sizeof_nwarps;
if (tv.ntags == 0) { return; }

char* h_buffer = (char*)The_Pinned_Arena()->alloc(total_buf_size);
char* d_buffer = (char*)The_Arena()->alloc(total_buf_size);
const auto d_tags = tv.d_tags;
const auto d_nwarps = tv.d_nwarps;
const auto ntags = tv.ntags;
const auto ntotwarps = tv.ntotwarps;
constexpr auto nthreads = TagVector<TagType>::nthreads;

std::memcpy(h_buffer, tags.data(), sizeof_tags);
std::memcpy(h_buffer+offset_nwarps, nwarps.data(), sizeof_nwarps);
Gpu::htod_memcpy_async(d_buffer, h_buffer, total_buf_size);

auto d_tags = reinterpret_cast<TagType*>(d_buffer);
auto d_nwarps = reinterpret_cast<int*>(d_buffer+offset_nwarps);

constexpr int nthreads = 256;
constexpr int nwarps_per_block = nthreads/Gpu::Device::warp_size;
int nblocks = (ntotwarps + nwarps_per_block-1) / nwarps_per_block;

amrex::ignore_unused(l_ntotwarps);
AMREX_ASSERT(l_ntotwarps+nwarps_per_block-1 < Long(std::numeric_limits<int>::max()));

amrex::launch(nblocks, nthreads, Gpu::gpuStream(),
amrex::launch(tv.nblocks, nthreads, Gpu::gpuStream(),
#ifdef AMREX_USE_SYCL
[=] AMREX_GPU_DEVICE (sycl::nd_item<1> const& item) noexcept
[[sycl::reqd_work_group_size(nthreads)]]
@@ -241,20 +334,60 @@
tagparfor_call_f( icell, d_tags[tag_id], f);
#endif
});
}
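The kernel body is collapsed in the diff, but the mapping it must implement follows from the data `TagVector` stages: `d_nwarps` is a prefix sum of warps per tag, each warp covers `warp_size` consecutive cells of one tag, and a thread finds its tag by binary search over the prefix array. A host-side sketch of that index arithmetic (inferred from the staged data, not copied from the kernel):

```cpp
#include <algorithm>
#include <vector>

struct TagCell { int tag_id; int icell; };

// Given a global thread id, recover (tag_id, icell) the way the prefix-sum
// layout implies. nwarps has ntags+1 entries with nwarps[0] == 0.
TagCell locate (int gid, std::vector<int> const& nwarps, int warp_size) {
    int wid  = gid / warp_size;  // which warp this thread is in
    int lane = gid % warp_size;  // position within the warp
    // First tag whose starting warp is beyond wid, minus one.
    auto it = std::upper_bound(nwarps.begin(), nwarps.end(), wid);
    int tag_id = int(it - nwarps.begin()) - 1;
    int icell  = (wid - nwarps[tag_id]) * warp_size + lane;
    return {tag_id, icell};  // icell may exceed the tag size; the callback
                             // receives the true size so it can bounds-check
}

int main () {
    std::vector<int> nwarps{0, 3, 5};   // two tags: 3 warps, then 2 warps
    auto tc = locate(100, nwarps, 32);  // thread 100 -> warp 3, lane 4
    return (tc.tag_id == 1 && tc.icell == 4) ? 0 : 1;
}
```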

#else // ifdef AMREX_USE_GPU

Gpu::streamSynchronize();
The_Pinned_Arena()->free(h_buffer);
The_Arena()->free(d_buffer);
template <class TagType, class F>
void
ParallelFor_doit (TagVector<TagType> const& tv, F const& f)
{
AMREX_ALWAYS_ASSERT(tv.is_defined());

if (tv.ntags == 0) { return; }

const auto d_tags = tv.d_tags;
const auto ntags = tv.ntags;

#ifdef AMREX_USE_OMP
#pragma omp parallel
#endif
for (int itag = 0; itag < ntags; ++itag) {

const auto& t = d_tags[itag];

if constexpr (is_box_tag(t)) {
const auto lo = amrex::lbound(t.box());
const auto hi = amrex::ubound(t.box());

for (int k = lo.z; k <= hi.z; ++k) {
for (int j = lo.y; j <= hi.y; ++j) {
AMREX_PRAGMA_SIMD
for (int i = lo.x; i <= hi.x; ++i) {
f(0, 1, i, j, k, t);
}
}
}
} else {
const auto size = t.size();

AMREX_PRAGMA_SIMD
for (int i = 0; i < size; ++i) {
f(i, size, t);
}
}
}
}
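On the CPU path the same wrapper functor is invoked directly: box tags get one call per cell with `icell = 0, ncells = 1`, so the `icell < ncells` guard (needed on GPU, where a warp can run past a tag's end) always passes, and vector tags get `(i, size, tag)`. A plain-C++ sketch of the wrapper shape the public box-tag `ParallelFor` conceptually generates around a user `(i,j,k,tag)` lambda (names hypothetical):

```cpp
#include <cstdio>

struct MyBoxTag { /* hypothetical tag payload */ };

int main () {
    auto user_f = [] (int i, int j, int k, MyBoxTag const&) {
        std::printf("cell (%d,%d,%d)\n", i, j, k);
    };
    // On GPU, icell/ncells index a warp's cells and may overrun the box end;
    // the CPU path calls with (0, 1, ...) so the guard is always true.
    auto wrapper = [=] (int icell, int ncells,
                        int i, int j, int k, MyBoxTag const& tag) {
        if (icell < ncells) { user_f(i, j, k, tag); }
    };
    wrapper(0, 1, 2, 3, 4, MyBoxTag{});  // mimics the CPU fallback's call shape
    return 0;
}
```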

#endif

}

template <class TagType, class F>
std::enable_if_t<std::is_same<std::decay_t<decltype(std::declval<TagType>().box())>,
Box>::value>
ParallelFor (Vector<TagType> const& tags, int ncomp, F && f)
std::enable_if_t<std::is_same_v<std::decay_t<decltype(std::declval<TagType>().box())>, Box>>
ParallelFor (TagVector<TagType> const& tv, int ncomp, F const& f)
{
detail::ParallelFor_doit(tags,
detail::ParallelFor_doit(tv,
[=] AMREX_GPU_DEVICE (
#ifdef AMREX_USE_SYCL
sycl::nd_item<1> const& /*item*/,
@@ -270,10 +403,10 @@ ParallelFor (Vector<TagType> const& tags, int ncomp, F && f)
}

template <class TagType, class F>
std::enable_if_t<std::is_same<std::decay_t<decltype(std::declval<TagType>().box())>, Box>::value, void>
ParallelFor (Vector<TagType> const& tags, F && f)
std::enable_if_t<std::is_same_v<std::decay_t<decltype(std::declval<TagType>().box())>, Box>, void>
ParallelFor (TagVector<TagType> const& tv, F const& f)
{
detail::ParallelFor_doit(tags,
detail::ParallelFor_doit(tv,
[=] AMREX_GPU_DEVICE (
#ifdef AMREX_USE_SYCL
sycl::nd_item<1> const& /*item*/,
@@ -287,10 +420,10 @@ ParallelFor (Vector<TagType> const& tags, F && f)
}

template <class TagType, class F>
std::enable_if_t<std::is_integral<std::decay_t<decltype(std::declval<TagType>().size())> >::value, void>
ParallelFor (Vector<TagType> const& tags, F && f)
std::enable_if_t<std::is_integral_v<std::decay_t<decltype(std::declval<TagType>().size())> >, void>
ParallelFor (TagVector<TagType> const& tv, F const& f)
{
detail::ParallelFor_doit(tags,
detail::ParallelFor_doit(tv,
[=] AMREX_GPU_DEVICE (
#ifdef AMREX_USE_SYCL
sycl::nd_item<1> const& /*item*/,
@@ -303,7 +436,29 @@ ParallelFor (Vector<TagType> const& tags, F && f)
});
}

#endif
template <class TagType, class F>
std::enable_if_t<std::is_same_v<std::decay_t<decltype(std::declval<TagType>().box())>, Box>>
ParallelFor (Vector<TagType> const& tags, int ncomp, F && f)
{
TagVector<TagType> tv{tags};
ParallelFor(tv, ncomp, std::forward<F>(f));
}

template <class TagType, class F>
std::enable_if_t<std::is_same_v<std::decay_t<decltype(std::declval<TagType>().box())>, Box>, void>
ParallelFor (Vector<TagType> const& tags, F && f)
{
TagVector<TagType> tv{tags};
ParallelFor(tv, std::forward<F>(f));
}

template <class TagType, class F>
std::enable_if_t<std::is_integral_v<std::decay_t<decltype(std::declval<TagType>().size())> >, void>
ParallelFor (Vector<TagType> const& tags, F && f)
{
TagVector<TagType> tv{tags};
ParallelFor(tv, std::forward<F>(f));
}
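The retained `Vector<TagType>` overloads are now thin one-shot wrappers: each call still builds, uses, and destroys a temporary `TagVector`, preserving the old API and its behavior. Code that launches repeatedly over the same tags can hoist the `TagVector` out, which is the point of the PR (sketch; `f` and `g` are hypothetical kernels):

```cpp
// One-shot, as before: packs and uploads the tags on every call.
amrex::ParallelFor(tags, f);
amrex::ParallelFor(tags, g);  // pays the packing/upload cost again

// Reusable, new in this PR: pack once, launch many times.
amrex::TagVector<TagType> tv{tags};
amrex::ParallelFor(tv, f);
amrex::ParallelFor(tv, g);    // no re-pack, no re-upload
```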

}
