Skip to content

Commit

Permalink
combine saxpy with xpay
Browse files Browse the repository at this point in the history
  • Loading branch information
mbkuhn committed Mar 10, 2025
1 parent 247628b commit 73018e9
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 2 deletions.
74 changes: 74 additions & 0 deletions Src/Base/AMReX_FabArray.H
Original file line number Diff line number Diff line change
Expand Up @@ -1524,6 +1524,24 @@ public:
value_type a, const FabArray<FAB>& x, int xcomp,
value_type b, const FabArray<FAB>& y, int ycomp,
int dstcomp, int numcomp, const IntVect& nghost);

/**
* \brief y = x2+a2*(y+a1*x1)
*
* \param y FabArray y
* \param a_saxpy scalar a_saxpy
* \param x_saxpy FabArray x_saxpy
* \param a_xpay scalar a_xpay
* \param x_xpay FabArray x_xpay
* \param xcomp starting component of x
* \param ycomp starting component of y
* \param ncomp number of components
* \param nghost number of ghost cells
*/
template <class F=FAB, std::enable_if_t<IsBaseFab<F>::value,int> = 0>
static void Saxpy_Xpay (FabArray<FAB>& y, value_type a_saxpy, FabArray<FAB> const& x_saxpy,
value_type a_xpay, FabArray<FAB> const& x_xpay,
int xcomp, int ycomp, int ncomp, IntVect const& nghost);
};


Expand Down Expand Up @@ -2918,6 +2936,62 @@ FabArray<FAB>::Xpay (FabArray<FAB>& y, value_type a, FabArray<FAB> const& x,
}
}

template <class FAB>
template <class F, std::enable_if_t<IsBaseFab<F>::value,int> FOO>
void FabArray<FAB>::Saxpy_Xpay (FabArray<FAB>& y, value_type a_saxpy, FabArray<FAB> const& x_saxpy,
value_type a_xpay, FabArray<FAB> const& x_xpay,
int xcomp, int ycomp, int ncomp, IntVect const& nghost)
{
AMREX_ASSERT(y.boxArray() == x_saxpy.boxArray());
AMREX_ASSERT(y.distributionMap == x_saxpy.distributionMap);
AMREX_ASSERT(y.nGrowVect().allGE(nghost) && x_saxpy.nGrowVect().allGE(nghost));

AMREX_ASSERT(y.boxArray() == x_xpay.boxArray());
AMREX_ASSERT(y.distributionMap == x_xpay.distributionMap);
AMREX_ASSERT(y.nGrowVect().allGE(nghost) && x_xpay.nGrowVect().allGE(nghost));

BL_PROFILE("FabArray::Saxpy_Xpay()");

#ifdef AMREX_USE_GPU
if (Gpu::inLaunchRegion() && y.isFusingCandidate()) {
auto const& yma = y.arrays();
auto const& xma_s = x_saxpy.const_arrays();
auto const& xma_x = x_xpay.const_arrays();
ParallelFor(y, nghost, ncomp,
[=] AMREX_GPU_DEVICE (int box_no, int i, int j, int k, int n) noexcept
{
yma[box_no](i,j,k,ycomp+n) += a_saxpy * xma_s[box_no](i,j,k,xcomp+n);
yma[box_no](i,j,k,ycomp+n) = xma_x[box_no](i,j,k,n+xcomp)
+ a_xpay * yma[box_no](i,j,k,n+ycomp);
});
if (!Gpu::inNoSyncRegion()) {
Gpu::streamSynchronize();
}
} else
#endif
{
#ifdef AMREX_USE_OMP
#pragma omp parallel if (Gpu::notInLaunchRegion())
#endif
for (MFIter mfi(y,TilingIfNotGPU()); mfi.isValid(); ++mfi)
{
const Box& bx = mfi.growntilebox(nghost);

if (bx.ok()) {
auto const& xfab_s = x_saxpy.const_array(mfi);
auto const& xfab_x = x_xpay.const_array(mfi);
auto const& yfab = y.array(mfi);
AMREX_HOST_DEVICE_PARALLEL_FOR_4D( bx, ncomp, i, j, k, n,
{
yfab(i,j,k,ycomp+n) += a_saxpy * xfab_s(i,j,k,xcomp+n);
yfab(i,j,k,ycomp+n) = xfab_x(i,j,k,xcomp+n)
+ a_xpay * yfab(i,j,k,ycomp+n);
});
}
}
}
}

template <class FAB>
template <class F, std::enable_if_t<IsBaseFab<F>::value,int> FOO>
void
Expand Down
9 changes: 9 additions & 0 deletions Src/Base/AMReX_FabArrayUtility.H
Original file line number Diff line number Diff line change
Expand Up @@ -1858,6 +1858,15 @@ void Xpay (MF& dst, typename MF::value_type a, MF const& src, int scomp, int dco
MF::Xpay(dst, a, src, scomp, dcomp, ncomp, nghost);
}

//! dst += a * src; dst = src + a * dst
template <class MF, std::enable_if_t<IsMultiFabLike_v<MF>,int> = 0>
void Saxpy_Xpay (MF& dst, typename MF::value_type a_saxpy, MF const& src_saxpy,
typename MF::value_type a_xpay, MF const& src_xpay, int scomp, int dcomp,
int ncomp, IntVect const& nghost)
{
MF::Saxpy_Xpay(dst, a_saxpy, src_saxpy, a_xpay, src_xpay, scomp, dcomp, ncomp, nghost);
}

//! dst = a*src_a + b*src_b
template <class MF, std::enable_if_t<IsMultiFabLike_v<MF>,int> = 0>
void LinComb (MF& dst,
Expand Down
15 changes: 15 additions & 0 deletions Src/Base/AMReX_MultiFab.H
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,21 @@ public:

using FabArray<FArrayBox>::Xpay;

/**
* \brief dst = src2+a2*(dst+a1*src1)
*/
static void Saxpy_Xpay (MultiFab& dst,
Real a_saxpy,
const MultiFab& src_saxpy,
Real a_xpay,
const MultiFab& src_xpay,
int srccomp,
int dstcomp,
int numcomp,
int nghost);

using FabArray<FArrayBox>::Saxpy_Xpay;

/**
* \brief dst = a*x + b*y
*/
Expand Down
8 changes: 8 additions & 0 deletions Src/Base/AMReX_MultiFab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,14 @@ MultiFab::Xpay (MultiFab& dst, Real a, const MultiFab& src,
Xpay(dst,a,src,srccomp,dstcomp,numcomp,IntVect(nghost));
}

void
MultiFab::Saxpy_Xpay (MultiFab& dst, Real a_saxpy, const MultiFab& src_saxpy,
Real a_xpay, const MultiFab& src_xpay,
int srccomp, int dstcomp, int numcomp, int nghost)
{
Saxpy_Xpay(dst,a_saxpy,src_saxpy,a_xpay,src_xpay,srccomp,dstcomp,numcomp,IntVect(nghost));
}

void
MultiFab::LinComb (MultiFab& dst,
Real a, const MultiFab& x, int xcomp,
Expand Down
5 changes: 3 additions & 2 deletions Src/LinearSolvers/MLMG/AMReX_MLCGSolver.H
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,9 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
else
{
const RT beta = (rho/rho_1)*(alpha/omega);
Saxpy(p, -omega, v, 0, 0, ncomp, nghost); // p += -omega*v
Xpay(p, beta, r, 0, 0, ncomp, nghost); // p = r + beta*p
// two operations: p += -omega*v; p = r + beta*p
// same as: p = r + beta*(p - omega*v)
Saxpy_Xpay(p, -omega, v, beta, r, 0, 0, ncomp, nghost);
}
Lp.apply(amrlev, mglev, v, p, MLLinOpT<MF>::BCMode::Homogeneous, MLLinOpT<MF>::StateMode::Correction);
Lp.normalize(amrlev, mglev, v);
Expand Down

0 comments on commit 73018e9

Please sign in to comment.