diff --git a/Src/Base/AMReX_FabArray.H b/Src/Base/AMReX_FabArray.H
index 610d9c6ec2..ad3da00c28 100644
--- a/Src/Base/AMReX_FabArray.H
+++ b/Src/Base/AMReX_FabArray.H
@@ -1542,6 +1542,28 @@ public:
     static void Saxpy_Xpay (FabArray<FAB>& y, value_type a_saxpy, FabArray<FAB> const& x_saxpy,
                             value_type a_xpay, FabArray<FAB> const& x_xpay,
                             int xcomp, int ycomp, int ncomp, IntVect const& nghost);
+
+    /**
+     * \brief y1 += a1*x1; y2 += a2*x2;
+     *
+     * Both updates are applied in the same pass over y1, so all four
+     * FabArrays must be defined on the same BoxArray and DistributionMapping.
+     *
+     * \param y1     FabArray y1
+     * \param a1     scalar a1
+     * \param x1     FabArray x1
+     * \param y2     FabArray y2
+     * \param a2     scalar a2
+     * \param x2     FabArray x2
+     * \param xcomp  starting component of x1, x2
+     * \param ycomp  starting component of y1, y2
+     * \param ncomp  number of components
+     * \param nghost number of ghost cells
+     */
+    template <class F=FAB, std::enable_if_t<IsBaseFab<F>::value,int> = 0>
+    static void Saxpy_Saxpy (FabArray<FAB>& y1, value_type a1, FabArray<FAB> const& x1,
+                             FabArray<FAB>& y2, value_type a2, FabArray<FAB> const& x2,
+                             int xcomp, int ycomp, int ncomp, IntVect const& nghost);
 };
 
 
@@ -2992,6 +3014,67 @@ void FabArray<FAB>::Saxpy_Xpay (FabArray<FAB>& y, value_type a_saxpy, FabArray<FAB>
     }
 }
 
+template <class FAB>
+template <class F, std::enable_if_t<IsBaseFab<F>::value,int> FOO>
+void FabArray<FAB>::Saxpy_Saxpy (FabArray<FAB>& y1, value_type a1, FabArray<FAB> const& x1,
+                                 FabArray<FAB>& y2, value_type a2, FabArray<FAB> const& x2,
+                                 int xcomp, int ycomp, int ncomp, IntVect const& nghost)
+{
+    AMREX_ASSERT(y1.boxArray() == x1.boxArray());
+    AMREX_ASSERT(y1.distributionMap == x1.distributionMap);
+    AMREX_ASSERT(y1.nGrowVect().allGE(nghost) && x1.nGrowVect().allGE(nghost));
+
+    AMREX_ASSERT(y2.boxArray() == x2.boxArray());
+    AMREX_ASSERT(y2.distributionMap == x2.distributionMap);
+    AMREX_ASSERT(y2.nGrowVect().allGE(nghost) && x2.nGrowVect().allGE(nghost));
+
+    // Both code paths below iterate over y1 only and reuse its box index for
+    // y2 and x2, so the second pair must share y1's layout as well.
+    AMREX_ASSERT(y1.boxArray() == y2.boxArray());
+    AMREX_ASSERT(y1.distributionMap == y2.distributionMap);
+
+    BL_PROFILE("FabArray::Saxpy_Saxpy()");
+
+#ifdef AMREX_USE_GPU
+    if (Gpu::inLaunchRegion() && y1.isFusingCandidate()) {
+        auto const& y1ma = y1.arrays();
+        auto const& x1ma = x1.const_arrays();
+        auto const& y2ma = y2.arrays();
+        auto const& x2ma = x2.const_arrays();
+        ParallelFor(y1, nghost, ncomp,
+        [=] AMREX_GPU_DEVICE (int box_no, int i, int j, int k, int n) noexcept
+        {
+            y1ma[box_no](i,j,k,ycomp+n) += a1 * x1ma[box_no](i,j,k,xcomp+n);
+            y2ma[box_no](i,j,k,ycomp+n) += a2 * x2ma[box_no](i,j,k,xcomp+n);
+        });
+        if (!Gpu::inNoSyncRegion()) {
+            Gpu::streamSynchronize();
+        }
+    } else
+#endif
+    {
+#ifdef AMREX_USE_OMP
+#pragma omp parallel if (Gpu::notInLaunchRegion())
+#endif
+        for (MFIter mfi(y1,TilingIfNotGPU()); mfi.isValid(); ++mfi)
+        {
+            const Box& bx = mfi.growntilebox(nghost);
+
+            if (bx.ok()) {
+                auto const& x1fab = x1.const_array(mfi);
+                auto const& y1fab = y1.array(mfi);
+                auto const& x2fab = x2.const_array(mfi);
+                auto const& y2fab = y2.array(mfi);
+                AMREX_HOST_DEVICE_PARALLEL_FOR_4D( bx, ncomp, i, j, k, n,
+                {
+                    y1fab(i,j,k,ycomp+n) += a1 * x1fab(i,j,k,xcomp+n);
+                    y2fab(i,j,k,ycomp+n) += a2 * x2fab(i,j,k,xcomp+n);
+                });
+            }
+        }
+    }
+}
+
 template <class FAB>
 template <class F, std::enable_if_t<IsBaseFab<F>::value,int> FOO>
 void
diff --git a/Src/Base/AMReX_FabArrayUtility.H b/Src/Base/AMReX_FabArrayUtility.H
index e7b0a9e61d..67128054b2 100644
--- a/Src/Base/AMReX_FabArrayUtility.H
+++ b/Src/Base/AMReX_FabArrayUtility.H
@@ -1867,6 +1867,15 @@ void Saxpy_Xpay (MF& dst, typename MF::value_type a_saxpy, MF const& src_saxpy,
     MF::Saxpy_Xpay(dst, a_saxpy, src_saxpy, a_xpay, src_xpay, scomp, dcomp, ncomp, nghost);
 }
 
+//! dst1 += a1 * src1; dst2 += a2 * src2
+template <typename MF, std::enable_if_t<IsMultiFabLike_v<MF>,int> = 0>
+void Saxpy_Saxpy (MF& dst1, typename MF::value_type a1, MF const& src1,
+                  MF& dst2, typename MF::value_type a2, MF const& src2, int scomp, int dcomp,
+                  int ncomp, IntVect const& nghost)
+{
+    MF::Saxpy_Saxpy(dst1, a1, src1, dst2, a2, src2, scomp, dcomp, ncomp, nghost);
+}
+
 //! 
dst = a*src_a + b*src_b
 template <typename MF, std::enable_if_t<IsMultiFabLike_v<MF>,int> = 0>
 void LinComb (MF& dst,
diff --git a/Src/Base/AMReX_MultiFab.H b/Src/Base/AMReX_MultiFab.H
index 2466370b7f..a16950f553 100644
--- a/Src/Base/AMReX_MultiFab.H
+++ b/Src/Base/AMReX_MultiFab.H
@@ -635,6 +635,25 @@ public:
 
     using FabArray<FArrayBox>::Saxpy_Xpay;
 
+    /**
+     * \brief dst1 += a1*src1; dst2 += a2*src2
+     *
+     * Convenience overload for an integer ghost-cell count; it forwards to
+     * the FabArray version with IntVect(nghost).
+     */
+    static void Saxpy_Saxpy (MultiFab& dst1,
+                             Real a1,
+                             const MultiFab& src1,
+                             MultiFab& dst2,
+                             Real a2,
+                             const MultiFab& src2,
+                             int srccomp,
+                             int dstcomp,
+                             int numcomp,
+                             int nghost);
+
+    using FabArray<FArrayBox>::Saxpy_Saxpy;
+
     /**
      * \brief dst = a*x + b*y
      */
diff --git a/Src/Base/AMReX_MultiFab.cpp b/Src/Base/AMReX_MultiFab.cpp
index 4dc10c281d..064fc2bd63 100644
--- a/Src/Base/AMReX_MultiFab.cpp
+++ b/Src/Base/AMReX_MultiFab.cpp
@@ -363,6 +363,14 @@ MultiFab::Saxpy_Xpay (MultiFab& dst, Real a_saxpy, const MultiFab& src_saxpy,
     Saxpy_Xpay(dst,a_saxpy,src_saxpy,a_xpay,src_xpay,srccomp,dstcomp,numcomp,IntVect(nghost));
 }
 
+void
+MultiFab::Saxpy_Saxpy (MultiFab& dst1, Real a1, const MultiFab& src1,
+                       MultiFab& dst2, Real a2, const MultiFab& src2,
+                       int srccomp, int dstcomp, int numcomp, int nghost)
+{
+    Saxpy_Saxpy(dst1,a1,src1,dst2,a2,src2,srccomp,dstcomp,numcomp,IntVect(nghost));
+}
+
 void
 MultiFab::LinComb (MultiFab& dst,
                    Real a, const MultiFab& x, int xcomp,
diff --git a/Src/LinearSolvers/MLMG/AMReX_MLCGSolver.H b/Src/LinearSolvers/MLMG/AMReX_MLCGSolver.H
index d67f4e83de..2e7065b6cc 100644
--- a/Src/LinearSolvers/MLMG/AMReX_MLCGSolver.H
+++ b/Src/LinearSolvers/MLMG/AMReX_MLCGSolver.H
@@ -182,8 +182,8 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
         {
             ret = 2; break;
         }
-        Saxpy(sol, alpha, p, 0, 0, ncomp, nghost); // sol += alpha * p
-        Saxpy(r, -alpha, v, 0, 0, ncomp, nghost); // r += -alpha * v
+        // sol += alpha * p; r += -alpha * v
+        Saxpy_Saxpy(sol, alpha, p, r, -alpha, v, 0, 0, ncomp, nghost);
 
         rnorm = norm_inf(r);
 
@@ -218,8 +218,11 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
         {
             ret = 3; break;
         }
+        // Deliberately not fused into Saxpy_Saxpy: r would be passed both as the
+        // read-only x1 and as the writable y2 of the same call.
+        // NOTE(review): confirm the aliasing rules of Array4 before fusing these.
         Saxpy(sol, omega, r, 0, 0, ncomp, nghost); // sol += omega * r
         Saxpy(r, -omega, t, 0, 0, ncomp, nghost); // r += -omega * t
 
         rnorm = norm_inf(r);
 
@@ -360,8 +363,8 @@ MLCGSolverT<MF>::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
                            << " rho " << rho
                            << " alpha " << alpha << '\n';
         }
-        Saxpy(sol, alpha, p, 0, 0, ncomp, nghost); // sol += alpha * p
-        Saxpy(r, -alpha, q, 0, 0, ncomp, nghost); // r += -alpha * q
+        // sol += alpha * p; r += -alpha * q
+        Saxpy_Saxpy(sol, alpha, p, r, -alpha, q, 0, 0, ncomp, nghost);
         rnorm = norm_inf(r);
 
         if ( verbose > 2 )