#include <AMReX_Config.H>

#include <AMReX_AlgPartition.H>
#include <AMReX_AlgVector.H>
#include <AMReX_Gpu.H>
#include <AMReX_INT.H>
#include <AMReX_Scan.H>

#include <limits>
@@ -57,10 +58,19 @@ public:
57
58
template <typename F>
58
59
void setVal (F const & f);
59
60
61
+ template <typename U> friend void SpMV (AlgVector<U>& y, SpMatrix<U> const & A, AlgVector<U> const & x);
62
+
60
63
private:
61
64
62
65
void define_doit (int nnz);
63
66
67
+ void startComm (AlgVector<T> const & x) const ;
68
+ void finishComm (AlgVector<T>& y) const ;
69
+
70
+ #ifdef AMREX_USE_MPI
71
+ void prepare_for_comm ();
72
+ #endif
73
+
64
74
struct CSR {
65
75
Vec<T> mat;
66
76
Vec<Long> col_index;
@@ -71,7 +81,12 @@ private:
71
81
AlgPartition m_partition;
72
82
Long m_row_begin = 0 ;
73
83
Long m_row_end = 0 ;
74
- CSR m_data; // We might need two CSRs, one for local data, the other for remote data
84
+ CSR m_data;
85
+
86
+ #ifdef AMREX_USE_MPI
87
+ CSR m_data_remote;
88
+ bool m_comm_prepared = false ;
89
+ #endif
75
90
};
76
91
77
92
template <typename T, template <typename > class Allocator >
@@ -149,6 +164,119 @@ void SpMatrix<T,Allocator>::setVal (F const& f)
149
164
});
150
165
}
151
166
167
+ template <typename T, template <typename > class Allocator >
168
+ void SpMatrix<T,Allocator>::startComm (AlgVector<T> const & x) const
169
+ {
170
+ if (this ->numLocalRows () == 0 ) { return ; }
171
+
172
+ #ifndef AMREX_USE_MPI
173
+ amrex::ignore_unused (x);
174
+ #else
175
+ if (this ->numLocalRows () == this ->numGlobalRows ()) { return ; }
176
+
177
+ const_cast <SpMatrix<T,Allocator>*>(this )->prepare_for_comm ();
178
+
179
+ #endif
180
+ }
181
+
182
+ template <typename T, template <typename > class Allocator >
183
+ void SpMatrix<T,Allocator>::finishComm (AlgVector<T>& y) const
184
+ {
185
+ if (this ->numLocalRows () == 0 ) { return ; }
186
+
187
+ #ifndef AMREX_USE_MPI
188
+ amrex::ignore_unused (y);
189
+ #else
190
+ if (this ->numLocalRows () == this ->numGlobalRows ()) { return ; }
191
+ #endif
192
+ }
193
+
194
#ifdef AMREX_USE_MPI

//! One-time setup for distributed SpMV: split the matrix into a square
//! local part (columns in [m_row_begin, m_row_end)) and a remote part
//! (all other columns), so the local product can overlap communication.
//!
//! On exit: m_data holds only local entries with column indices shifted to
//! local numbering, m_data_remote holds the off-rank entries with global
//! column indices, and m_data.row_offset has been remapped to the compacted
//! local storage.  Idempotent via m_comm_prepared.
template <typename T, template <typename> class Allocator>
void SpMatrix<T,Allocator>::prepare_for_comm ()
{
    if (m_comm_prepared) { return; }

    // First, we need to split the matrix into two parts, a square matrix
    // for pure local operations and another part for remote operations.

    Long all_nnz = m_data.nnz;
    Long local_nnz;
    // pfsum[i] = number of local entries strictly before i (exclusive scan
    // of the "is local" predicate); doubles as the scatter index below.
    Gpu::DeviceVector<Long> pfsum(all_nnz);
    auto* p_pfsum = pfsum.data();
    auto row_begin = m_row_begin;
    auto row_end = m_row_end;
    if (m_data.nnz < Long(std::numeric_limits<int>::max())) {
        // Common case: nnz fits in int, use the cheaper 32-bit-index scan.
        auto const* pcol = m_data.col_index.data();
        local_nnz = Scan::PrefixSum<Long>(int(all_nnz),
            [=] AMREX_GPU_DEVICE (int i) -> Long {
                return (pcol[i] >= row_begin &&
                        pcol[i] <  row_end); },
            [=] AMREX_GPU_DEVICE (int i, Long const& x) {
                p_pfsum[i] = x; },
            Scan::Type::exclusive, Scan::retSum);
    } else {
        auto const* pcol = m_data.col_index.data();
        local_nnz = Scan::PrefixSum<Long>(all_nnz,
            [=] AMREX_GPU_DEVICE (Long i) -> Long {
                return (pcol[i] >= row_begin &&
                        pcol[i] <  row_end); },
            [=] AMREX_GPU_DEVICE (Long i, Long const& x) {
                p_pfsum[i] = x; },
            Scan::Type::exclusive, Scan::retSum);
    }

    m_data.nnz = local_nnz;
    Long remote_nnz = all_nnz - local_nnz;
    m_data_remote.nnz = remote_nnz;

    if (local_nnz != all_nnz) {
        // Stable partition: local entries keep their relative order in the
        // compacted arrays (index pfsum[i]); remote entries go to i-pfsum[i].
        m_data_remote.mat.resize(remote_nnz);
        m_data_remote.col_index.resize(remote_nnz);
        Vec<T> new_mat(local_nnz);
        Vec<Long> new_col(local_nnz);
        auto const* pmat = m_data.mat.data();
        auto const* pcol = m_data.col_index.data();
        auto* pmat_l = new_mat.data();
        auto* pcol_l = new_col.data();
        auto* pmat_r = m_data_remote.mat.data();
        auto* pcol_r = m_data_remote.col_index.data();
        ParallelFor(all_nnz, [=] AMREX_GPU_DEVICE (Long i)
        {
            auto ps = p_pfsum[i];
            auto local = (pcol[i] >= row_begin &&
                          pcol[i] <  row_end);
            if (local) {
                pmat_l[ps] = pmat[i];
                pcol_l[ps] = pcol[i] - row_begin; // shift the column index to local
            } else {
                pmat_r[i-ps] = pmat[i];
                pcol_r[i-ps] = pcol[i]; // keep global index; remapped later
            }
        });
        // Remap row offsets into the compacted local storage.  The last
        // offset equals all_nnz, which is one past the end of pfsum, so it
        // is set explicitly to local_nnz instead of being looked up.
        auto noffset = Long(m_data.row_offset.size());
        auto* pro = m_data.row_offset.data();
        ParallelFor(noffset, [=] AMREX_GPU_DEVICE (Long i)
        {
            if (i < noffset-1) {
                pro[i] = p_pfsum[pro[i]];
            } else {
                pro[i] = local_nnz;
            }
        });
        // Kernels above read pmat/pcol, which are freed by the swaps below;
        // synchronize before the old storage goes away.
        Gpu::streamSynchronize();
        m_data.mat.swap(new_mat);
        m_data.col_index.swap(new_col);

        // xxxxx TODO: still need to work on m_data_remote
    }

    m_comm_prepared = true;
}

#endif
279
+
152
280
}
153
281
154
282
#endif
0 commit comments