Skip to content

Commit 06f2a75

Browse files
committed
Start MPI
1 parent ca53939 commit 06f2a75

File tree

3 files changed

+193
-10
lines changed

3 files changed

+193
-10
lines changed

Src/LinearSolvers/AMReX_SpMV.H

+4
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ void SpMV (AlgVector<T>& y, SpMatrix<T> const& A, AlgVector<T> const& x)
2626
AMREX_ASSERT(x.numLocalRows() == y.numLocalRows());
2727
AMREX_ASSERT(x.numGlobalRows() == y.numGlobalRows());
2828

29+
A.startComm(x);
30+
2931
T * AMREX_RESTRICT py = y.data();
3032
T const* AMREX_RESTRICT px = x.data();
3133
T const* AMREX_RESTRICT mat = A.data();
@@ -163,6 +165,8 @@ void SpMV (AlgVector<T>& y, SpMatrix<T> const& A, AlgVector<T> const& x)
163165
}
164166

165167
#endif
168+
169+
A.finishComm(y);
166170
}
167171

168172
}

Src/LinearSolvers/AMReX_SpMatrix.H

+129-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <AMReX_Config.H>
44

55
#include <AMReX_AlgPartition.H>
6+
#include <AMReX_AlgVector.H>
67
#include <AMReX_Gpu.H>
78
#include <AMReX_INT.H>
89
#include <AMReX_Scan.H>
@@ -57,10 +58,19 @@ public:
5758
template <typename F>
5859
void setVal (F const& f);
5960

61+
template <typename U> friend void SpMV(AlgVector<U>& y, SpMatrix<U> const& A, AlgVector<U> const& x);
62+
6063
private:
6164

6265
void define_doit (int nnz);
6366

67+
void startComm (AlgVector<T> const& x) const;
68+
void finishComm (AlgVector<T>& y) const;
69+
70+
#ifdef AMREX_USE_MPI
71+
void prepare_for_comm ();
72+
#endif
73+
6474
struct CSR {
6575
Vec<T> mat;
6676
Vec<Long> col_index;
@@ -71,7 +81,12 @@ private:
7181
AlgPartition m_partition;
7282
Long m_row_begin = 0;
7383
Long m_row_end = 0;
74-
CSR m_data; // We might need two CSRs, one for local data, the other for remote data
84+
CSR m_data;
85+
86+
#ifdef AMREX_USE_MPI
87+
CSR m_data_remote;
88+
bool m_comm_prepared = false;
89+
#endif
7590
};
7691

7792
template <typename T, template<typename> class Allocator>
@@ -149,6 +164,119 @@ void SpMatrix<T,Allocator>::setVal (F const& f)
149164
});
150165
}
151166

167+
template <typename T, template<typename> class Allocator>
168+
void SpMatrix<T,Allocator>::startComm (AlgVector<T> const& x) const
169+
{
170+
if (this->numLocalRows() == 0) { return; }
171+
172+
#ifndef AMREX_USE_MPI
173+
amrex::ignore_unused(x);
174+
#else
175+
if (this->numLocalRows() == this->numGlobalRows()) { return; }
176+
177+
const_cast<SpMatrix<T,Allocator>*>(this)->prepare_for_comm();
178+
179+
#endif
180+
}
181+
182+
template <typename T, template<typename> class Allocator>
183+
void SpMatrix<T,Allocator>::finishComm (AlgVector<T>& y) const
184+
{
185+
if (this->numLocalRows() == 0) { return; }
186+
187+
#ifndef AMREX_USE_MPI
188+
amrex::ignore_unused(y);
189+
#else
190+
if (this->numLocalRows() == this->numGlobalRows()) { return; }
191+
#endif
192+
}
193+
194+
#ifdef AMREX_USE_MPI
195+
196+
template <typename T, template<typename> class Allocator>
197+
void SpMatrix<T,Allocator>::prepare_for_comm ()
198+
{
199+
if (m_comm_prepared) { return; }
200+
201+
// First, we need to split the matrix into two parts, a square matrix
202+
// for pure local operations and another part for remote operations.
203+
204+
Long all_nnz = m_data.nnz;
205+
Long local_nnz;
206+
Gpu::DeviceVector<Long> pfsum(all_nnz);
207+
auto* p_pfsum = pfsum.data();
208+
auto row_begin = m_row_begin;
209+
auto row_end = m_row_end;
210+
if (m_data.nnz < Long(std::numeric_limits<int>::max())) {
211+
auto const* pcol = m_data.col_index.data();
212+
local_nnz = Scan::PrefixSum<Long>(int(all_nnz),
213+
[=] AMREX_GPU_DEVICE (int i) -> Long {
214+
return (pcol[i] >= row_begin &&
215+
pcol[i] < row_end); },
216+
[=] AMREX_GPU_DEVICE (int i, Long const& x) {
217+
p_pfsum[i] = x; },
218+
Scan::Type::exclusive, Scan::retSum);
219+
} else {
220+
auto const* pcol = m_data.col_index.data();
221+
local_nnz = Scan::PrefixSum<Long>(all_nnz,
222+
[=] AMREX_GPU_DEVICE (Long i) -> Long {
223+
return (pcol[i] >= row_begin &&
224+
pcol[i] < row_end); },
225+
[=] AMREX_GPU_DEVICE (Long i, Long const& x) {
226+
p_pfsum[i] = x; },
227+
Scan::Type::exclusive, Scan::retSum);
228+
}
229+
230+
m_data.nnz = local_nnz;
231+
Long remote_nnz = all_nnz - local_nnz;
232+
m_data_remote.nnz = remote_nnz;
233+
234+
if (local_nnz != all_nnz) {
235+
m_data_remote.mat.resize(remote_nnz);
236+
m_data_remote.col_index.resize(remote_nnz);
237+
Vec<T> new_mat(local_nnz);
238+
Vec<Long> new_col(local_nnz);
239+
auto const* pmat = m_data.mat.data();
240+
auto const* pcol = m_data.col_index.data();
241+
auto* pmat_l = new_mat.data();
242+
auto* pcol_l = new_col.data();
243+
auto* pmat_r = m_data_remote.mat.data();
244+
auto* pcol_r = m_data_remote.col_index.data();
245+
ParallelFor(all_nnz, [=] AMREX_GPU_DEVICE (Long i)
246+
{
247+
auto ps = p_pfsum[i];
248+
auto local = (pcol[i] >= row_begin &&
249+
pcol[i] < row_end);
250+
if (local) {
251+
pmat_l[ps] = pmat[i];
252+
pcol_l[ps] = pcol[i] - row_begin; // shift the column index to local
253+
} else {
254+
pmat_r[i-ps] = pmat[i];
255+
pcol_r[i-ps] = pcol[i];
256+
}
257+
});
258+
auto noffset = Long(m_data.row_offset.size());
259+
auto* pro = m_data.row_offset.data();
260+
ParallelFor(noffset, [=] AMREX_GPU_DEVICE (Long i)
261+
{
262+
if (i < noffset-1) {
263+
pro[i] = p_pfsum[pro[i]];
264+
} else {
265+
pro[i] = local_nnz;
266+
}
267+
});
268+
Gpu::streamSynchronize();
269+
m_data.mat.swap(new_mat);
270+
m_data.col_index.swap(new_col);
271+
272+
// xxxxx TODO: still need to work on m_data_remote
273+
}
274+
275+
m_comm_prepared = true;
276+
}
277+
278+
#endif
279+
152280
}
153281

154282
#endif

Tests/Algebra/GMRES/main.cpp

+60-9
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#include <AMReX_GMRES_MV.H>
22

33
#include <AMReX.H>
4-
#include <AMReX_Random.H>
54

65
using namespace amrex;
76

@@ -13,16 +12,61 @@ int main (int argc, char* argv[])
1312
Long n = domain.numPts();
1413
AlgVector<Real> xvec(n);
1514
AlgVector<Real> bvec(xvec.partition());
15+
AlgVector<Real> exact(xvec.partition());
16+
17+
Real a = Real(1.e-6);
18+
Real dx = Real(2)*amrex::Math::pi<Real>()/Real(domain.length(0));
19+
20+
// The system is a * phi - del dot grad phi.
21+
// Where phi = sin^5(x)*sin^5(y)*sin^5(z)
22+
23+
BoxIndexer box_indexer(domain);
1624

1725
// Initialize bvec
18-
amrex::FillRandomNormal(bvec.data(), bvec.numLocalRows(), Real(0), Real(1));
26+
{
27+
auto* rhs = bvec.data();
28+
auto* phi = exact.data();
29+
auto nrows = bvec.numLocalRows();
30+
auto ib = bvec.globalBegin();
31+
ParallelFor(nrows, [=] AMREX_GPU_DEVICE (Long lrow)
32+
{
33+
auto row = lrow + ib; // global row index
34+
IntVect cell = box_indexer.intVect(row);
35+
#if (AMREX_SPACEDIM == 1)
36+
auto x = (cell[0]+Real(0.5))*dx;
37+
auto phi0 = Math::powi<5>(std::sin(x));
38+
auto phixm = Math::powi<5>(std::sin(x-dx));
39+
auto phixp = Math::powi<5>(std::sin(x+dx));
40+
rhs[lrow] = a*phi0 + (Real(2)*phi0-phixm-phixp) / (dx*dx);
41+
#elif (AMREX_SPACEDIM == 2)
42+
auto x = (cell[0]+Real(0.5))*dx;
43+
auto y = (cell[1]+Real(0.5))*dx;
44+
auto phi0 = Math::powi<5>(std::sin(x)*std::sin(y));
45+
auto phixm = Math::powi<5>(std::sin(x-dx)*std::sin(y));
46+
auto phixp = Math::powi<5>(std::sin(x+dx)*std::sin(y));
47+
auto phiym = Math::powi<5>(std::sin(x)*std::sin(y-dx));
48+
auto phiyp = Math::powi<5>(std::sin(x)*std::sin(y+dx));
49+
rhs[lrow] = a*phi0 + (Real(4)*phi0-phixm-phixp-phiym-phiyp) / (dx*dx);
50+
#else
51+
auto x = (cell[0]+Real(0.5))*dx;
52+
auto y = (cell[1]+Real(0.5))*dx;
53+
auto z = (cell[2]+Real(0.5))*dx;
54+
auto phi0 = Math::powi<5>(std::sin(x)*std::sin(y)*std::sin(z));
55+
auto phixm = Math::powi<5>(std::sin(x-dx)*std::sin(y)*std::sin(z));
56+
auto phixp = Math::powi<5>(std::sin(x+dx)*std::sin(y)*std::sin(z));
57+
auto phiym = Math::powi<5>(std::sin(x)*std::sin(y-dx)*std::sin(z));
58+
auto phiyp = Math::powi<5>(std::sin(x)*std::sin(y+dx)*std::sin(z));
59+
auto phizm = Math::powi<5>(std::sin(x)*std::sin(y)*std::sin(z-dx));
60+
auto phizp = Math::powi<5>(std::sin(x)*std::sin(y)*std::sin(z+dx));
61+
rhs[lrow] = a*phi0 + (Real(6)*phi0-phixm-phixp-phiym-phiyp-phizm-phizp) / (dx*dx);
62+
#endif
63+
phi[lrow] = phi0;
64+
});
65+
}
1966

2067
// Initial guess
2168
xvec.setVal(0);
2269

23-
BoxIndexer box_indexer(domain);
24-
25-
// a * phi - del dot grad phi. For simplicity, let a=1 and dx=1.
2670
// cross stencil w/ periodic boundaries
2771
auto set_stencil = [=] AMREX_GPU_DEVICE (Long row, Long* col, Real* val)
2872
{
@@ -37,7 +81,7 @@ int main (int argc, char* argv[])
3781
}
3882
Long row2 = domain.index(cell2);
3983
col[i] = row2;
40-
val[i] = Real(-1.0);
84+
val[i] = Real(-1.0)/(dx*dx);
4185
++i;
4286

4387
if (cell[idim] == domain.bigEnd(idim)) {
@@ -47,11 +91,11 @@ int main (int argc, char* argv[])
4791
}
4892
row2 = domain.index(cell2);
4993
col[i] = row2;
50-
val[i] = Real(-1.0);
94+
val[i] = Real(-1.0)/(dx*dx);
5195
++i;
5296
}
5397
col[i] = row;
54-
val[i] = Real(2*AMREX_SPACEDIM+1);
98+
val[i] = Real(2*AMREX_SPACEDIM)/(dx*dx) + a;
5599
};
56100

57101
int num_non_zeros = 2*AMREX_SPACEDIM+1;
@@ -60,7 +104,14 @@ int main (int argc, char* argv[])
60104

61105
GMRES_MV<Real> gmres(&mat);
62106
gmres.setVerbose(2);
63-
gmres.solve(xvec, bvec, Real(1.e-10), Real(0.0));
107+
auto eps = (sizeof(Real) == 4) ? Real(1.e-5) : Real (1.e-12);
108+
gmres.solve(xvec, bvec, eps, Real(0.0));
109+
110+
// Check the solution
111+
amrex::Axpy(xvec, Real(-1.0), exact);
112+
auto error = xvec.norminf();
113+
amrex::Print() << " Max norm error: " << error << "\n";
114+
AMREX_ALWAYS_ASSERT(error*10 < eps);
64115
}
65116
amrex::Finalize();
66117
}

0 commit comments

Comments
 (0)