Skip to content

Commit 75083a5

Browse files
committed
FFT: Add batch support
1 parent 823ec7f commit 75083a5

10 files changed

+419
-180
lines changed

Src/FFT/AMReX_FFT_Helper.H

+59-40
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ namespace amrex::FFT
4747

4848
enum struct Direction { forward, backward, both, none };
4949

50-
enum struct DomainStrategy { slab, pencil };
50+
enum struct DomainStrategy { automatic, slab, pencil };
5151

5252
AMREX_ENUM( Boundary, periodic, even, odd );
5353

@@ -56,15 +56,28 @@ enum struct Kind { none, r2c_f, r2c_b, c2c_f, c2c_b, r2r_ee_f, r2r_ee_b,
5656

5757
struct Info
5858
{
59-
//! Supported only in 3D. When batch_mode is true, FFT is performed on
59+
//! Domain decomposition strategy.
60+
DomainStrategy domain_strategy = DomainStrategy::automatic;
61+
62+
//! For the automatic strategy, this is the size per process at which we switch from
63+
//! slab to pencil.
64+
int pencil_threshold = 12;
65+
66+
//! Supported only in 3D. When twod_mode is true, FFT is performed on
6067
//! the first two dimensions only and the third dimension size is the
6168
//! batch size.
62-
bool batch_mode = false;
69+
bool twod_mode = false;
70+
71+
//! Batched FFT size. Only supported in R2C, not R2X.
72+
int batch_size = 1;
6373

6474
//! Max number of processes to use
6575
int nprocs = std::numeric_limits<int>::max();
6676

67-
Info& setBatchMode (bool x) { batch_mode = x; return *this; }
77+
Info& setDomainStrategy (DomainStrategy s) { domain_strategy = s; return *this; }
78+
Info& setPencilThreshold (int t) { pencil_threshold = t; return *this; }
79+
Info& setTwoDMode (bool x) { twod_mode = x; return *this; }
80+
Info& setBatchSize (int bsize) { batch_size = bsize; return *this; }
6881
Info& setNumProcs (int n) { nprocs = n; return *this; }
6982
};
7083

@@ -170,7 +183,7 @@ struct Plan
170183
}
171184

172185
template <Direction D>
173-
void init_r2c (Box const& box, T* pr, VendorComplex* pc, bool is_2d_transform = false)
186+
void init_r2c (Box const& box, T* pr, VendorComplex* pc, bool is_2d_transform = false, int ncomp = 1)
174187
{
175188
static_assert(D == Direction::forward || D == Direction::backward);
176189

@@ -198,6 +211,7 @@ struct Plan
198211
howmany = (rank == 1) ? AMREX_D_TERM(1, *box.length(1), *box.length(2))
199212
: AMREX_D_TERM(1, *1 , *box.length(2));
200213
#endif
214+
howmany *= ncomp;
201215

202216
amrex::ignore_unused(nc);
203217

@@ -293,10 +307,10 @@ struct Plan
293307
}
294308

295309
template <Direction D, int M>
296-
void init_r2c (IntVectND<M> const& fft_size, void*, void*, bool cache);
310+
void init_r2c (IntVectND<M> const& fft_size, void*, void*, bool cache, int ncomp = 1);
297311

298312
template <Direction D>
299-
void init_c2c (Box const& box, VendorComplex* p)
313+
void init_c2c (Box const& box, VendorComplex* p, int ncomp = 1)
300314
{
301315
static_assert(D == Direction::forward || D == Direction::backward);
302316

@@ -307,6 +321,7 @@ struct Plan
307321

308322
n = box.length(0);
309323
howmany = AMREX_D_TERM(1, *box.length(1), *box.length(2));
324+
howmany *= ncomp;
310325

311326
#if defined(AMREX_USE_CUDA)
312327
AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
@@ -1131,7 +1146,7 @@ struct Plan
11311146
}
11321147
};
11331148

1134-
using Key = std::tuple<IntVectND<3>,Direction,Kind>;
1149+
using Key = std::tuple<IntVectND<3>,int,Direction,Kind>;
11351150
using PlanD = typename Plan<double>::VendorPlan;
11361151
using PlanF = typename Plan<float>::VendorPlan;
11371152

@@ -1143,7 +1158,7 @@ void add_vendor_plan_f (Key const& key, PlanF plan);
11431158

11441159
template <typename T>
11451160
template <Direction D, int M>
1146-
void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool cache)
1161+
void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool cache, int ncomp)
11471162
{
11481163
static_assert(D == Direction::forward || D == Direction::backward);
11491164

@@ -1154,10 +1169,10 @@ void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool
11541169

11551170
n = 1;
11561171
for (auto s : fft_size) { n *= s; }
1157-
howmany = 1;
1172+
howmany = ncomp;
11581173

11591174
#if defined(AMREX_USE_GPU)
1160-
Key key = {fft_size.template expand<3>(), D, kind};
1175+
Key key = {fft_size.template expand<3>(), ncomp, D, kind};
11611176
if (cache) {
11621177
VendorPlan* cached_plan = nullptr;
11631178
if constexpr (std::is_same_v<float,T>) {
@@ -1174,27 +1189,34 @@ void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool
11741189
amrex::ignore_unused(cache);
11751190
#endif
11761191

1192+
int len[M];
1193+
for (int i = 0; i < M; ++i) {
1194+
len[i] = fft_size[M-1-i];
1195+
}
1196+
1197+
int nc = fft_size[0]/2+1;
1198+
for (int i = 1; i < M; ++i) {
1199+
nc *= fft_size[i];
1200+
}
1201+
11771202
#if defined(AMREX_USE_CUDA)
11781203

11791204
AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
11801205
AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
11811206
cufftType type;
1207+
int n_in, n_out;
11821208
if constexpr (D == Direction::forward) {
11831209
type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
1210+
n_in = n;
1211+
n_out = nc;
11841212
} else {
11851213
type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
1214+
n_in = nc;
1215+
n_out = n;
11861216
}
11871217
std::size_t work_size;
1188-
if constexpr (M == 1) {
1189-
AMREX_CUFFT_SAFE_CALL
1190-
(cufftMakePlan1d(plan, fft_size[0], type, howmany, &work_size));
1191-
} else if constexpr (M == 2) {
1192-
AMREX_CUFFT_SAFE_CALL
1193-
(cufftMakePlan2d(plan, fft_size[1], fft_size[0], type, &work_size));
1194-
} else if constexpr (M == 3) {
1195-
AMREX_CUFFT_SAFE_CALL
1196-
(cufftMakePlan3d(plan, fft_size[2], fft_size[1], fft_size[0], type, &work_size));
1197-
}
1218+
AMREX_CUFFT_SAFE_CALL
1219+
(cufftMakePlanMany(plan, M, len, nullptr, 1, n_in, nullptr, 1, n_out, type, howmany, &work_size));
11981220

11991221
#elif defined(AMREX_USE_HIP)
12001222

@@ -1219,19 +1241,21 @@ void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool
12191241
if (M == 1) {
12201242
pp = new mkl_desc_r(fft_size[0]);
12211243
} else {
1222-
std::vector<std::int64_t> len(M);
1244+
std::vector<std::int64_t> len64(M);
12231245
for (int idim = 0; idim < M; ++idim) {
1224-
len[idim] = fft_size[M-1-idim];
1246+
len64[idim] = len[idim];
12251247
}
1226-
pp = new mkl_desc_r(len);
1248+
pp = new mkl_desc_r(len64);
12271249
}
12281250
#ifndef AMREX_USE_MKL_DFTI_2024
12291251
pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
12301252
oneapi::mkl::dft::config_value::NOT_INPLACE);
12311253
#else
12321254
pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
12331255
#endif
1234-
1256+
pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
1257+
pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n);
1258+
pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
12351259
std::vector<std::int64_t> strides(M+1);
12361260
strides[0] = 0;
12371261
strides[M] = 1;
@@ -1258,29 +1282,24 @@ void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool
12581282
return;
12591283
}
12601284

1261-
int size_for_row_major[M];
1262-
for (int idim = 0; idim < M; ++idim) {
1263-
size_for_row_major[idim] = fft_size[M-1-idim];
1264-
}
1265-
12661285
if constexpr (std::is_same_v<float,T>) {
12671286
if constexpr (D == Direction::forward) {
1268-
plan = fftwf_plan_dft_r2c
1269-
(M, size_for_row_major, (float*)pf, (fftwf_complex*)pb,
1287+
plan = fftwf_plan_many_dft_r2c
1288+
(M, len, howmany, (float*)pf, nullptr, 1, n, (fftwf_complex*)pb, nullptr, 1, nc,
12701289
FFTW_ESTIMATE);
12711290
} else {
1272-
plan = fftwf_plan_dft_c2r
1273-
(M, size_for_row_major, (fftwf_complex*)pb, (float*)pf,
1291+
plan = fftwf_plan_many_dft_c2r
1292+
(M, len, howmany, (fftwf_complex*)pb, nullptr, 1, nc, (float*)pf, nullptr, 1, n,
12741293
FFTW_ESTIMATE);
12751294
}
12761295
} else {
12771296
if constexpr (D == Direction::forward) {
1278-
plan = fftw_plan_dft_r2c
1279-
(M, size_for_row_major, (double*)pf, (fftw_complex*)pb,
1297+
plan = fftw_plan_many_dft_r2c
1298+
(M, len, howmany, (double*)pf, nullptr, 1, n, (fftw_complex*)pb, nullptr, 1, nc,
12801299
FFTW_ESTIMATE);
12811300
} else {
1282-
plan = fftw_plan_dft_c2r
1283-
(M, size_for_row_major, (fftw_complex*)pb, (double*)pf,
1301+
plan = fftw_plan_many_dft_c2r
1302+
(M, len, howmany, (fftw_complex*)pb, nullptr, 1, nc, (double*)pf, nullptr, 1, n,
12841303
FFTW_ESTIMATE);
12851304
}
12861305
}
@@ -1508,10 +1527,10 @@ namespace detail
15081527
b = make_box(b);
15091528
}
15101529
auto const& ng = make_iv(mf.nGrowVect());
1511-
FA submf(BoxArray(std::move(bl)), mf.DistributionMap(), 1, ng, MFInfo{}.SetAlloc(false));
1530+
FA submf(BoxArray(std::move(bl)), mf.DistributionMap(), mf.nComp(), ng, MFInfo{}.SetAlloc(false));
15121531
using FAB = typename FA::fab_type;
15131532
for (MFIter mfi(submf, MFItInfo().DisableDeviceSync()); mfi.isValid(); ++mfi) {
1514-
submf.setFab(mfi, FAB(mfi.fabbox(), 1, mf[mfi].dataPtr()));
1533+
submf.setFab(mfi, FAB(mfi.fabbox(), mf.nComp(), mf[mfi].dataPtr()));
15151534
}
15161535
return submf;
15171536
}

Src/FFT/AMReX_FFT_OpenBCSolver.H

+12-11
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ Box OpenBCSolver<T>::make_grown_domain (Box const& domain, Info const& info)
3737
{
3838
IntVect len = domain.length();
3939
#if (AMREX_SPACEDIM == 3)
40-
if (info.batch_mode) { len[2] = 0; }
40+
if (info.twod_mode) { len[2] = 0; }
4141
#else
4242
amrex::ignore_unused(info);
4343
#endif
@@ -48,18 +48,19 @@ template <typename T>
4848
OpenBCSolver<T>::OpenBCSolver (Box const& domain, Info const& info)
4949
: m_domain(domain),
5050
m_info(info),
51-
m_r2c(OpenBCSolver<T>::make_grown_domain(domain,info), info)
51+
m_r2c(OpenBCSolver<T>::make_grown_domain(domain,info),
52+
m_info.setDomainStrategy(FFT::DomainStrategy::slab))
5253
{
5354
#if (AMREX_SPACEDIM == 3)
54-
if (m_info.batch_mode) {
55+
if (m_info.twod_mode) {
5556
auto gdom = make_grown_domain(domain,m_info);
5657
gdom.enclosedCells(2);
5758
gdom.setSmall(2, 0);
5859
int nprocs = std::min({ParallelContext::NProcsSub(),
5960
m_info.nprocs,
6061
m_domain.length(2)});
6162
gdom.setBig(2, nprocs-1);
62-
m_r2c_green = std::make_unique<R2C<T>>(gdom,info);
63+
m_r2c_green = std::make_unique<R2C<T>>(gdom,m_info);
6364
auto [sd, ord] = m_r2c_green->getSpectralData();
6465
m_G_fft = cMF(*sd, amrex::make_alias, 0, 1);
6566
} else
@@ -78,7 +79,7 @@ void OpenBCSolver<T>::setGreensFunction (F const& greens_function)
7879
{
7980
BL_PROFILE("OpenBCSolver::setGreensFunction");
8081

81-
auto* infab = m_info.batch_mode ? detail::get_fab(m_r2c_green->m_rx)
82+
auto* infab = m_info.twod_mode ? detail::get_fab(m_r2c_green->m_rx)
8283
: detail::get_fab(m_r2c.m_rx);
8384
auto const& lo = m_domain.smallEnd();
8485
auto const& lo3 = lo.dim3();
@@ -87,7 +88,7 @@ void OpenBCSolver<T>::setGreensFunction (F const& greens_function)
8788
auto const& a = infab->array();
8889
auto box = infab->box();
8990
GpuArray<int,3> nimages{1,1,1};
90-
int ndims = m_info.batch_mode ? AMREX_SPACEDIM-1 : AMREX_SPACEDIM;
91+
int ndims = m_info.twod_mode ? AMREX_SPACEDIM-1 : AMREX_SPACEDIM;
9192
for (int idim = 0; idim < ndims; ++idim) {
9293
if (box.smallEnd(idim) == lo[idim] && box.length(idim) == 2*len[idim]) {
9394
box.growHi(idim, -len[idim]+1); // +1 to include the middle plane
@@ -129,13 +130,13 @@ void OpenBCSolver<T>::setGreensFunction (F const& greens_function)
129130
});
130131
}
131132

132-
if (m_info.batch_mode) {
133+
if (m_info.twod_mode) {
133134
m_r2c_green->forward(m_r2c_green->m_rx);
134135
} else {
135136
m_r2c.forward(m_r2c.m_rx);
136137
}
137138

138-
if (!m_info.batch_mode) {
139+
if (!m_info.twod_mode) {
139140
auto [sd, ord] = m_r2c.getSpectralData();
140141
amrex::ignore_unused(ord);
141142
auto const* srcfab = detail::get_fab(*sd);
@@ -166,7 +167,7 @@ void OpenBCSolver<T>::solve (MF& phi, MF const& rho)
166167
inmf.setVal(T(0));
167168
inmf.ParallelCopy(rho, 0, 0, 1);
168169

169-
m_r2c.m_openbc_half = !m_info.batch_mode;
170+
m_r2c.m_openbc_half = !m_info.twod_mode;
170171
m_r2c.forward(inmf);
171172
m_r2c.m_openbc_half = false;
172173

@@ -183,7 +184,7 @@ void OpenBCSolver<T>::solve (MF& phi, MF const& rho)
183184
Box const& rhobox = rhofab->box();
184185
#if (AMREX_SPACEDIM == 3)
185186
Long leng = gfab->box().numPts();
186-
if (m_info.batch_mode) {
187+
if (m_info.twod_mode) {
187188
AMREX_ASSERT(gfab->box().length(2) == 1 &&
188189
leng == (rhobox.length(0) * rhobox.length(1)));
189190
} else {
@@ -204,7 +205,7 @@ void OpenBCSolver<T>::solve (MF& phi, MF const& rho)
204205
}
205206
}
206207

207-
m_r2c.m_openbc_half = !m_info.batch_mode;
208+
m_r2c.m_openbc_half = !m_info.twod_mode;
208209
m_r2c.backward_doit(phi, phi.nGrowVect());
209210
m_r2c.m_openbc_half = false;
210211
}

Src/FFT/AMReX_FFT_Poisson.H

+2-2
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ public:
127127
}
128128
}
129129
Info info{};
130-
info.setBatchMode(true);
130+
info.setTwoDMode(true);
131131
if (periodic_xy) {
132132
m_r2c = std::make_unique<R2C<typename MF::value_type>>(m_geom.Domain(),
133133
info);
@@ -145,7 +145,7 @@ public:
145145
std::make_pair(Boundary::periodic,Boundary::periodic),
146146
std::make_pair(Boundary::even,Boundary::even))},
147147
m_r2c(std::make_unique<R2C<typename MF::value_type>>
148-
(geom.Domain(), Info().setBatchMode(true)))
148+
(geom.Domain(), Info().setTwoDMode(true)))
149149
{
150150
#if (AMREX_SPACEDIM == 3)
151151
AMREX_ALWAYS_ASSERT(geom.isPeriodic(0) && geom.isPeriodic(1));

0 commit comments

Comments
 (0)