@@ -47,7 +47,7 @@ namespace amrex::FFT
47
47
48
48
enum struct Direction { forward, backward, both, none };
49
49
50
- enum struct DomainStrategy { slab, pencil };
50
+ enum struct DomainStrategy { automatic, slab, pencil };
51
51
52
52
AMREX_ENUM ( Boundary, periodic, even, odd );
53
53
@@ -56,15 +56,28 @@ enum struct Kind { none, r2c_f, r2c_b, c2c_f, c2c_b, r2r_ee_f, r2r_ee_b,
56
56
57
57
struct Info
58
58
{
59
- // ! Supported only in 3D. When batch_mode is true, FFT is performed on
59
+ // ! Domain decomposition strategy.
60
+ DomainStrategy domain_strategy = DomainStrategy::automatic;
61
+
62
+ // ! For automatic strategy, this is the size per process at which we switch from
63
+ // ! slab to pencil.
64
+ int pencil_threshold = 12 ;
65
+
66
+ // ! Supported only in 3D. When twod_mode is true, FFT is performed on
60
67
// ! the first two dimensions only and the third dimension size is the
61
68
// ! batch size.
62
- bool batch_mode = false ;
69
+ bool twod_mode = false ;
70
+
71
+ // ! Batched FFT size. Only supported in R2C, not R2X.
72
+ int batch_size = 1 ;
63
73
64
74
// ! Max number of processes to use
65
75
int nprocs = std::numeric_limits<int >::max();
66
76
67
- Info& setBatchMode (bool x) { batch_mode = x; return *this ; }
77
+ Info& setDomainStrategy (DomainStrategy s) { domain_strategy = s; return *this ; }
78
+ Info& setPencilThreshold (int t) { pencil_threshold = t; return *this ; }
79
+ Info& setTwoDMode (bool x) { twod_mode = x; return *this ; }
80
+ Info& setBatchSize (int bsize) { batch_size = bsize; return *this ; }
68
81
Info& setNumProcs (int n) { nprocs = n; return *this ; }
69
82
};
70
83
@@ -170,7 +183,7 @@ struct Plan
170
183
}
171
184
172
185
template <Direction D>
173
- void init_r2c (Box const & box, T* pr, VendorComplex* pc, bool is_2d_transform = false )
186
+ void init_r2c (Box const & box, T* pr, VendorComplex* pc, bool is_2d_transform = false , int ncomp = 1 )
174
187
{
175
188
static_assert (D == Direction::forward || D == Direction::backward);
176
189
@@ -198,6 +211,7 @@ struct Plan
198
211
howmany = (rank == 1 ) ? AMREX_D_TERM (1 , *box.length (1 ), *box.length (2 ))
199
212
: AMREX_D_TERM (1 , *1 , *box.length (2 ));
200
213
#endif
214
+ howmany *= ncomp;
201
215
202
216
amrex::ignore_unused (nc);
203
217
@@ -293,10 +307,10 @@ struct Plan
293
307
}
294
308
295
309
template <Direction D, int M>
296
- void init_r2c (IntVectND<M> const & fft_size, void *, void *, bool cache);
310
+ void init_r2c (IntVectND<M> const & fft_size, void *, void *, bool cache, int ncomp = 1 );
297
311
298
312
template <Direction D>
299
- void init_c2c (Box const & box, VendorComplex* p)
313
+ void init_c2c (Box const & box, VendorComplex* p, int ncomp = 1 )
300
314
{
301
315
static_assert (D == Direction::forward || D == Direction::backward);
302
316
@@ -307,6 +321,7 @@ struct Plan
307
321
308
322
n = box.length (0 );
309
323
howmany = AMREX_D_TERM (1 , *box.length (1 ), *box.length (2 ));
324
+ howmany *= ncomp;
310
325
311
326
#if defined(AMREX_USE_CUDA)
312
327
AMREX_CUFFT_SAFE_CALL (cufftCreate (&plan));
@@ -1131,7 +1146,7 @@ struct Plan
1131
1146
}
1132
1147
};
1133
1148
1134
- using Key = std::tuple<IntVectND<3 >,Direction,Kind>;
1149
+ using Key = std::tuple<IntVectND<3 >,int , Direction,Kind>;
1135
1150
using PlanD = typename Plan<double >::VendorPlan;
1136
1151
using PlanF = typename Plan<float >::VendorPlan;
1137
1152
@@ -1143,7 +1158,7 @@ void add_vendor_plan_f (Key const& key, PlanF plan);
1143
1158
1144
1159
template <typename T>
1145
1160
template <Direction D, int M>
1146
- void Plan<T>::init_r2c (IntVectND<M> const & fft_size, void * pbf, void * pbb, bool cache)
1161
+ void Plan<T>::init_r2c (IntVectND<M> const & fft_size, void * pbf, void * pbb, bool cache, int ncomp )
1147
1162
{
1148
1163
static_assert (D == Direction::forward || D == Direction::backward);
1149
1164
@@ -1154,10 +1169,10 @@ void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool
1154
1169
1155
1170
n = 1 ;
1156
1171
for (auto s : fft_size) { n *= s; }
1157
- howmany = 1 ;
1172
+ howmany = ncomp ;
1158
1173
1159
1174
#if defined(AMREX_USE_GPU)
1160
- Key key = {fft_size.template expand <3 >(), D, kind};
1175
+ Key key = {fft_size.template expand <3 >(), ncomp, D, kind};
1161
1176
if (cache) {
1162
1177
VendorPlan* cached_plan = nullptr ;
1163
1178
if constexpr (std::is_same_v<float ,T>) {
@@ -1174,27 +1189,34 @@ void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool
1174
1189
amrex::ignore_unused (cache);
1175
1190
#endif
1176
1191
1192
+ int len[M];
1193
+ for (int i = 0 ; i < M; ++i) {
1194
+ len[i] = fft_size[M-1 -i];
1195
+ }
1196
+
1197
+ int nc = fft_size[0 ]/2 +1 ;
1198
+ for (int i = 1 ; i < M; ++i) {
1199
+ nc *= fft_size[i];
1200
+ }
1201
+
1177
1202
#if defined(AMREX_USE_CUDA)
1178
1203
1179
1204
AMREX_CUFFT_SAFE_CALL (cufftCreate (&plan));
1180
1205
AMREX_CUFFT_SAFE_CALL (cufftSetAutoAllocation (plan, 0 ));
1181
1206
cufftType type;
1207
+ int n_in, n_out;
1182
1208
if constexpr (D == Direction::forward) {
1183
1209
type = std::is_same_v<float ,T> ? CUFFT_R2C : CUFFT_D2Z;
1210
+ n_in = n;
1211
+ n_out = nc;
1184
1212
} else {
1185
1213
type = std::is_same_v<float ,T> ? CUFFT_C2R : CUFFT_Z2D;
1214
+ n_in = nc;
1215
+ n_out = n;
1186
1216
}
1187
1217
std::size_t work_size;
1188
- if constexpr (M == 1 ) {
1189
- AMREX_CUFFT_SAFE_CALL
1190
- (cufftMakePlan1d (plan, fft_size[0 ], type, howmany, &work_size));
1191
- } else if constexpr (M == 2 ) {
1192
- AMREX_CUFFT_SAFE_CALL
1193
- (cufftMakePlan2d (plan, fft_size[1 ], fft_size[0 ], type, &work_size));
1194
- } else if constexpr (M == 3 ) {
1195
- AMREX_CUFFT_SAFE_CALL
1196
- (cufftMakePlan3d (plan, fft_size[2 ], fft_size[1 ], fft_size[0 ], type, &work_size));
1197
- }
1218
+ AMREX_CUFFT_SAFE_CALL
1219
+ (cufftMakePlanMany (plan, M, len, nullptr , 1 , n_in, nullptr , 1 , n_out, type, howmany, &work_size));
1198
1220
1199
1221
#elif defined(AMREX_USE_HIP)
1200
1222
@@ -1219,19 +1241,21 @@ void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool
1219
1241
if (M == 1 ) {
1220
1242
pp = new mkl_desc_r (fft_size[0 ]);
1221
1243
} else {
1222
- std::vector<std::int64_t > len (M);
1244
+ std::vector<std::int64_t > len64 (M);
1223
1245
for (int idim = 0 ; idim < M; ++idim) {
1224
- len [idim] = fft_size[M- 1 - idim];
1246
+ len64 [idim] = len[ idim];
1225
1247
}
1226
- pp = new mkl_desc_r (len );
1248
+ pp = new mkl_desc_r (len64 );
1227
1249
}
1228
1250
#ifndef AMREX_USE_MKL_DFTI_2024
1229
1251
pp->set_value (oneapi::mkl::dft::config_param::PLACEMENT,
1230
1252
oneapi::mkl::dft::config_value::NOT_INPLACE);
1231
1253
#else
1232
1254
pp->set_value (oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
1233
1255
#endif
1234
-
1256
+ pp->set_value (oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
1257
+ pp->set_value (oneapi::mkl::dft::config_param::FWD_DISTANCE, n);
1258
+ pp->set_value (oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
1235
1259
std::vector<std::int64_t > strides (M+1 );
1236
1260
strides[0 ] = 0 ;
1237
1261
strides[M] = 1 ;
@@ -1258,29 +1282,24 @@ void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool
1258
1282
return ;
1259
1283
}
1260
1284
1261
- int size_for_row_major[M];
1262
- for (int idim = 0 ; idim < M; ++idim) {
1263
- size_for_row_major[idim] = fft_size[M-1 -idim];
1264
- }
1265
-
1266
1285
if constexpr (std::is_same_v<float ,T>) {
1267
1286
if constexpr (D == Direction::forward) {
1268
- plan = fftwf_plan_dft_r2c
1269
- (M, size_for_row_major, (float *)pf, (fftwf_complex*)pb,
1287
+ plan = fftwf_plan_many_dft_r2c
1288
+ (M, len, howmany, (float *)pf, nullptr , 1 , n, (fftwf_complex*)pb, nullptr , 1 , nc ,
1270
1289
FFTW_ESTIMATE);
1271
1290
} else {
1272
- plan = fftwf_plan_dft_c2r
1273
- (M, size_for_row_major, (fftwf_complex*)pb, (float *)pf,
1291
+ plan = fftwf_plan_many_dft_c2r
1292
+ (M, len, howmany, (fftwf_complex*)pb, nullptr , 1 , nc, (float *)pf, nullptr , 1 , n ,
1274
1293
FFTW_ESTIMATE);
1275
1294
}
1276
1295
} else {
1277
1296
if constexpr (D == Direction::forward) {
1278
- plan = fftw_plan_dft_r2c
1279
- (M, size_for_row_major, (double *)pf, (fftw_complex*)pb,
1297
+ plan = fftw_plan_many_dft_r2c
1298
+ (M, len, howmany, (double *)pf, nullptr , 1 , n, (fftw_complex*)pb, nullptr , 1 , nc ,
1280
1299
FFTW_ESTIMATE);
1281
1300
} else {
1282
- plan = fftw_plan_dft_c2r
1283
- (M, size_for_row_major, (fftw_complex*)pb, (double *)pf,
1301
+ plan = fftw_plan_many_dft_c2r
1302
+ (M, len, howmany, (fftw_complex*)pb, nullptr , 1 , nc, (double *)pf, nullptr , 1 , n ,
1284
1303
FFTW_ESTIMATE);
1285
1304
}
1286
1305
}
@@ -1508,10 +1527,10 @@ namespace detail
1508
1527
b = make_box (b);
1509
1528
}
1510
1529
auto const & ng = make_iv (mf.nGrowVect ());
1511
- FA submf (BoxArray (std::move (bl)), mf.DistributionMap (), 1 , ng, MFInfo{}.SetAlloc (false ));
1530
+ FA submf (BoxArray (std::move (bl)), mf.DistributionMap (), mf. nComp () , ng, MFInfo{}.SetAlloc (false ));
1512
1531
using FAB = typename FA::fab_type;
1513
1532
for (MFIter mfi (submf, MFItInfo ().DisableDeviceSync ()); mfi.isValid (); ++mfi) {
1514
- submf.setFab (mfi, FAB (mfi.fabbox (), 1 , mf[mfi].dataPtr ()));
1533
+ submf.setFab (mfi, FAB (mfi.fabbox (), mf. nComp () , mf[mfi].dataPtr ()));
1515
1534
}
1516
1535
return submf;
1517
1536
}
0 commit comments