Skip to content

Commit

Permalink
set up double saxpy's
Browse files Browse the repository at this point in the history
  • Loading branch information
mbkuhn committed Mar 10, 2025
1 parent 73018e9 commit 25951a3
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 6 deletions.
75 changes: 75 additions & 0 deletions Src/Base/AMReX_FabArray.H
Original file line number Diff line number Diff line change
Expand Up @@ -1542,6 +1542,25 @@ public:
static void Saxpy_Xpay (FabArray<FAB>& y, value_type a_saxpy, FabArray<FAB> const& x_saxpy,
value_type a_xpay, FabArray<FAB> const& x_xpay,
int xcomp, int ycomp, int ncomp, IntVect const& nghost);

/**
 * \brief Fused double saxpy: y1 += a1*x1; y2 += a2*x2;
 *
 * Both updates are requested through a single call, with the same
 * component range and ghost-cell extent applied to each pair.
 * y1 and y2 are modified in place; x1 and x2 are only read.
 *
 * \param y1     FabArray y1 (updated in place)
 * \param a1     scalar a1 multiplying x1
 * \param x1     FabArray x1 (read only)
 * \param y2     FabArray y2 (updated in place)
 * \param a2     scalar a2 multiplying x2
 * \param x2     FabArray x2 (read only)
 * \param xcomp  starting component of x1, x2
 * \param ycomp  starting component of y1, y2
 * \param ncomp  number of components
 * \param nghost number of ghost cells
 */
template <class F=FAB, std::enable_if_t<IsBaseFab<F>::value,int> = 0>
static void Saxpy_Saxpy (FabArray<FAB>& y1, value_type a1, FabArray<FAB> const& x1,
FabArray<FAB>& y2, value_type a2, FabArray<FAB> const& x2,
int xcomp, int ycomp, int ncomp, IntVect const& nghost);
};


Expand Down Expand Up @@ -2992,6 +3011,62 @@ void FabArray<FAB>::Saxpy_Xpay (FabArray<FAB>& y, value_type a_saxpy, FabArray<F
}
}

/**
 * \brief Fused double saxpy: y1 += a1*x1 and y2 += a2*x2 in a single sweep.
 *
 * For every box of y1 (grown by nghost) and every n in [0,ncomp) the loop
 * body performs
 *     y1(i,j,k,ycomp+n) += a1 * x1(i,j,k,xcomp+n);
 *     y2(i,j,k,ycomp+n) += a2 * x2(i,j,k,xcomp+n);
 * in that order.
 *
 * The iteration space is taken from y1 only, and x1, y2 and x2 are indexed
 * with y1's box index / MFIter. All four FabArrays must therefore share the
 * same BoxArray and DistributionMapping and carry at least nghost ghost
 * cells; this is checked with AMREX_ASSERT.
 *
 * NOTE(review): within one (cell, component) iteration the y1 update is
 * issued before the y2 update, so passing the same FabArray as x1 and y2
 * with xcomp == ycomp reads the not-yet-updated value for y1. Whether the
 * array accessors formally permit such aliasing is not visible here --
 * confirm before relying on it.
 *
 * \param y1     first destination, updated in place
 * \param a1     scalar multiplying x1
 * \param x1     first source
 * \param y2     second destination, updated in place
 * \param a2     scalar multiplying x2
 * \param x2     second source
 * \param xcomp  starting component of x1 and x2
 * \param ycomp  starting component of y1 and y2
 * \param ncomp  number of components
 * \param nghost number of ghost cells to include
 */
template <class FAB>
template <class F, std::enable_if_t<IsBaseFab<F>::value,int> FOO>
void FabArray<FAB>::Saxpy_Saxpy (FabArray<FAB>& y1, value_type a1, FabArray<FAB> const& x1,
                                 FabArray<FAB>& y2, value_type a2, FabArray<FAB> const& x2,
                                 int xcomp, int ycomp, int ncomp, IntVect const& nghost)
{
    AMREX_ASSERT(y1.boxArray() == x1.boxArray());
    AMREX_ASSERT(y1.distributionMap == x1.distributionMap);
    AMREX_ASSERT(y1.nGrowVect().allGE(nghost) && x1.nGrowVect().allGE(nghost));

    AMREX_ASSERT(y2.boxArray() == x2.boxArray());
    AMREX_ASSERT(y2.distributionMap == x2.distributionMap);
    AMREX_ASSERT(y2.nGrowVect().allGE(nghost) && x2.nGrowVect().allGE(nghost));

    // y2/x2 are accessed below through y1's box index and MFIter, so the
    // second pair must live on the same layout as the first pair.
    AMREX_ASSERT(y1.boxArray() == y2.boxArray());
    AMREX_ASSERT(y1.distributionMap == y2.distributionMap);

    BL_PROFILE("FabArray::Saxpy_Saxpy()");

#ifdef AMREX_USE_GPU
    if (Gpu::inLaunchRegion() && y1.isFusingCandidate()) {
        // One ParallelFor over all boxes of y1 handles both updates.
        auto const& y1ma = y1.arrays();
        auto const& x1ma = x1.const_arrays();
        auto const& y2ma = y2.arrays();
        auto const& x2ma = x2.const_arrays();
        ParallelFor(y1, nghost, ncomp,
        [=] AMREX_GPU_DEVICE (int box_no, int i, int j, int k, int n) noexcept
        {
            y1ma[box_no](i,j,k,ycomp+n) += a1 * x1ma[box_no](i,j,k,xcomp+n);
            y2ma[box_no](i,j,k,ycomp+n) += a2 * x2ma[box_no](i,j,k,xcomp+n);
        });
        // Skip the synchronization only when the caller opened a no-sync region.
        if (!Gpu::inNoSyncRegion()) {
            Gpu::streamSynchronize();
        }
    } else
#endif
    {
        // Box-by-box path (tiled when not running on GPU).
#ifdef AMREX_USE_OMP
#pragma omp parallel if (Gpu::notInLaunchRegion())
#endif
        for (MFIter mfi(y1,TilingIfNotGPU()); mfi.isValid(); ++mfi)
        {
            const Box& bx = mfi.growntilebox(nghost);

            if (bx.ok()) {
                auto const& x1fab = x1.const_array(mfi);
                auto const& y1fab = y1.array(mfi);
                auto const& x2fab = x2.const_array(mfi);
                auto const& y2fab = y2.array(mfi);
                AMREX_HOST_DEVICE_PARALLEL_FOR_4D( bx, ncomp, i, j, k, n,
                {
                    y1fab(i,j,k,ycomp+n) += a1 * x1fab(i,j,k,xcomp+n);
                    y2fab(i,j,k,ycomp+n) += a2 * x2fab(i,j,k,xcomp+n);
                });
            }
        }
    }
}

template <class FAB>
template <class F, std::enable_if_t<IsBaseFab<F>::value,int> FOO>
void
Expand Down
9 changes: 9 additions & 0 deletions Src/Base/AMReX_FabArrayUtility.H
Original file line number Diff line number Diff line change
Expand Up @@ -1867,6 +1867,15 @@ void Saxpy_Xpay (MF& dst, typename MF::value_type a_saxpy, MF const& src_saxpy,
MF::Saxpy_Xpay(dst, a_saxpy, src_saxpy, a_xpay, src_xpay, scomp, dcomp, ncomp, nghost);
}

/**
 * \brief dst1 += a1 * src1; dst2 += a2 * src2
 *
 * Generic entry point that forwards all arguments unchanged to the static
 * member MF::Saxpy_Saxpy.
 */
template <class MF, std::enable_if_t<IsMultiFabLike_v<MF>,int> = 0>
void Saxpy_Saxpy (MF& dst1,
                  typename MF::value_type a1,
                  MF const& src1,
                  MF& dst2,
                  typename MF::value_type a2,
                  MF const& src2,
                  int scomp,
                  int dcomp,
                  int ncomp,
                  IntVect const& nghost)
{
    MF::Saxpy_Saxpy(dst1, a1, src1,
                    dst2, a2, src2,
                    scomp, dcomp, ncomp, nghost);
}

//! dst = a*src_a + b*src_b
template <class MF, std::enable_if_t<IsMultiFabLike_v<MF>,int> = 0>
void LinComb (MF& dst,
Expand Down
16 changes: 16 additions & 0 deletions Src/Base/AMReX_MultiFab.H
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,22 @@ public:

using FabArray<FArrayBox>::Saxpy_Xpay;

/**
 * \brief dst1 += a1*src1; dst2 += a2*src2
 *
 * Overload taking the ghost-cell count as a single int.
 *
 * \param dst1    first destination MultiFab (updated in place)
 * \param a1      scalar multiplying src1
 * \param src1    first source MultiFab
 * \param dst2    second destination MultiFab (updated in place)
 * \param a2      scalar multiplying src2
 * \param src2    second source MultiFab
 * \param srccomp starting component in src1 and src2
 * \param dstcomp starting component in dst1 and dst2
 * \param numcomp number of components
 * \param nghost  number of ghost cells
 */
static void Saxpy_Saxpy (MultiFab& dst1,
Real a1,
const MultiFab& src1,
MultiFab& dst2,
Real a2,
const MultiFab& src2,
int srccomp,
int dstcomp,
int numcomp,
int nghost);

using FabArray<FArrayBox>::Saxpy_Saxpy;

/**
* \brief dst = a*x + b*y
*/
Expand Down
8 changes: 8 additions & 0 deletions Src/Base/AMReX_MultiFab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,14 @@ MultiFab::Saxpy_Xpay (MultiFab& dst, Real a_saxpy, const MultiFab& src_saxpy,
Saxpy_Xpay(dst,a_saxpy,src_saxpy,a_xpay,src_xpay,srccomp,dstcomp,numcomp,IntVect(nghost));
}

// dst1 += a1*src1; dst2 += a2*src2.
// Integer-nghost convenience overload: wraps nghost in an IntVect and
// forwards to the Saxpy_Saxpy overload that takes an IntVect.
void
MultiFab::Saxpy_Saxpy (MultiFab& dst1, Real a1, const MultiFab& src1,
                       MultiFab& dst2, Real a2, const MultiFab& src2,
                       int srccomp, int dstcomp, int numcomp, int nghost)
{
    const IntVect ng(nghost);
    Saxpy_Saxpy(dst1, a1, src1,
                dst2, a2, src2,
                srccomp, dstcomp, numcomp, ng);
}

void
MultiFab::LinComb (MultiFab& dst,
Real a, const MultiFab& x, int xcomp,
Expand Down
13 changes: 7 additions & 6 deletions Src/LinearSolvers/MLMG/AMReX_MLCGSolver.H
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,9 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
{
ret = 2; break;
}
Saxpy(sol, alpha, p, 0, 0, ncomp, nghost); // sol += alpha * p
Saxpy(r, -alpha, v, 0, 0, ncomp, nghost); // r += -alpha * v
// sol += alpha * p; r += -alpha * v
Saxpy_Saxpy(sol, alpha, p, r, -alpha, v, 0, 0, ncomp, nghost);


rnorm = norm_inf(r);

Expand Down Expand Up @@ -218,8 +219,8 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
{
ret = 3; break;
}
Saxpy(sol, omega, r, 0, 0, ncomp, nghost); // sol += omega * r
Saxpy(r, -omega, t, 0, 0, ncomp, nghost); // r += -omega * t
// sol += omega * r; r += -omega * t
Saxpy_Saxpy(sol, omega, r, r, -omega, t, 0, 0, ncomp, nghost);

rnorm = norm_inf(r);

Expand Down Expand Up @@ -360,8 +361,8 @@ MLCGSolverT<MF>::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
<< " rho " << rho
<< " alpha " << alpha << '\n';
}
Saxpy(sol, alpha, p, 0, 0, ncomp, nghost); // sol += alpha * p
Saxpy(r, -alpha, q, 0, 0, ncomp, nghost); // r += -alpha * q
// sol += alpha * p; r += -alpha * q
Saxpy_Saxpy(sol, alpha, p, r, -alpha, q, 0, 0, ncomp, nghost);
rnorm = norm_inf(r);

if ( verbose > 2 )
Expand Down

0 comments on commit 25951a3

Please sign in to comment.