Skip to content

Commit

Permalink
fix cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderSinn committed Feb 21, 2025
1 parent 9be5ad5 commit dc9c62d
Showing 1 changed file with 77 additions and 46 deletions.
123 changes: 77 additions & 46 deletions Src/Base/AMReX_TagParallelFor.H
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,18 @@ namespace detail {
return tag.size();
}

template <typename T>
constexpr
std::enable_if_t<std::is_same<std::decay_t<decltype(std::declval<T>().box())>, Box>::value,
bool>
is_box_tag (T const&) { return true; }

template <typename T>
constexpr
std::enable_if_t<std::is_integral<std::decay_t<decltype(std::declval<T>().size())> >::value,
bool>
is_box_tag (T const&) { return false; }

}

template <class TagType>
Expand Down Expand Up @@ -166,6 +178,7 @@ struct TagVector {
return;
}

#ifdef AMREX_USE_GPU
Long l_ntotwarps = 0;
ntotwarps = 0;
Vector<int> nwarps;
Expand Down Expand Up @@ -203,6 +216,16 @@ struct TagVector {

amrex::ignore_unused(l_ntotwarps);
AMREX_ALWAYS_ASSERT(l_ntotwarps+nwarps_per_block-1 < Long(std::numeric_limits<int>::max()));
#else
std::size_t sizeof_tags = ntags*sizeof(TagType);
h_buffer = (char*)The_Pinned_Arena()->alloc(sizeof_tags);

std::memcpy(h_buffer, tags.data(), sizeof_tags);

d_tags = reinterpret_cast<TagType*>(h_buffer);

defined = true;
#endif
}

void undefine () {
Expand All @@ -222,10 +245,10 @@ struct TagVector {
}
};

#ifdef AMREX_USE_GPU

namespace detail {

#ifdef AMREX_USE_GPU

template <typename T, typename F>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
std::enable_if_t<std::is_same<std::decay_t<decltype(std::declval<T>().box())>, Box>::value, void>
Expand Down Expand Up @@ -276,10 +299,10 @@ ParallelFor_doit (TagVector<TagType> const& tv, F && f)

if (tv.ntags == 0) { return; }

auto d_tags = tv.d_tags;
auto d_nwarps = tv.d_nwarps;
auto ntags = tv.ntags;
auto ntotwarps = tv.ntotwarps;
const auto d_tags = tv.d_tags;
const auto d_nwarps = tv.d_nwarps;
const auto ntags = tv.ntags;
const auto ntotwarps = tv.ntotwarps;
constexpr auto nthreads = tv.nthreads;

amrex::launch(tv.nblocks, nthreads, Gpu::gpuStream(),
Expand Down Expand Up @@ -317,6 +340,51 @@ ParallelFor_doit (TagVector<TagType> const& tv, F && f)
});
}

#else // ifdef AMREX_USE_GPU

template <class TagType, class F>
void
ParallelFor_doit (TagVector<TagType> const& tv, F &&)
{
AMREX_ALWAYS_ASSERT(tv.is_defined());

if (tv.ntags == 0) { return; }

const auto d_tags = tv.d_tags;
const auto ntags = tv.ntags;

#ifdef AMREX_USE_OMP
#pragma omp parallel
#endif
for (int itag = 0; itag < ntags; ++itag) {

const auto& t = d_tags[itag];

if constexpr (is_box_tag(t)) {
const auto lo = amrex::lbound(t.box());
const auto hi = amrex::ubound(t.box());

for (int k = lo.z; k <= hi.z; ++k) {
for (int j = lo.y; j <= hi.y; ++j) {
AMREX_PRAGMA_SIMD
for (int i = lo.x; i <= hi.x; ++i) {
f(0, 1, i, j, k, t);
}
}
}
} else {
const auto size = t.size();

AMREX_PRAGMA_SIMD
for (int i = 0; i < size; ++i) {
f(i, size, t);
}
}
}
}

#endif

}

template <class TagType, class F>
Expand Down Expand Up @@ -379,62 +447,25 @@ std::enable_if_t<std::is_same<std::decay_t<decltype(std::declval<TagType>().box(
ParallelFor (Vector<TagType> const& tags, int ncomp, F && f)
{
TagVector<TagType> tv{tags};

detail::ParallelFor_doit(tv,
[=] AMREX_GPU_DEVICE (
#ifdef AMREX_USE_SYCL
sycl::nd_item<1> const& /*item*/,
#endif
int icell, int ncells, int i, int j, int k, TagType const& tag) noexcept
{
if (icell < ncells) {
for (int n = 0; n < ncomp; ++n) {
f(i,j,k,n,tag);
}
}
});
ParallelFor(tv, ncomp, f);
}

template <class TagType, class F>
std::enable_if_t<std::is_same<std::decay_t<decltype(std::declval<TagType>().box())>, Box>::value, void>
ParallelFor (Vector<TagType> const& tags, F && f)
{
TagVector<TagType> tv{tags};

detail::ParallelFor_doit(tv,
[=] AMREX_GPU_DEVICE (
#ifdef AMREX_USE_SYCL
sycl::nd_item<1> const& /*item*/,
#endif
int icell, int ncells, int i, int j, int k, TagType const& tag) noexcept
{
if (icell < ncells) {
f(i,j,k,tag);
}
});
ParallelFor(tv, f);
}

template <class TagType, class F>
std::enable_if_t<std::is_integral<std::decay_t<decltype(std::declval<TagType>().size())> >::value, void>
ParallelFor (Vector<TagType> const& tags, F && f)
{
TagVector<TagType> tv{tags};

detail::ParallelFor_doit(tv,
[=] AMREX_GPU_DEVICE (
#ifdef AMREX_USE_SYCL
sycl::nd_item<1> const& /*item*/,
#endif
int icell, int ncells, TagType const& tag) noexcept
{
if (icell < ncells) {
f(icell,tag);
}
});
ParallelFor(tv, f);
}

#endif

}

#endif

0 comments on commit dc9c62d

Please sign in to comment.