Skip to content

Commit ee7a868

Browse files
committed
FFT: Support complex to complex
1 parent 2f64d71 commit ee7a868

File tree

6 files changed

+403
-103
lines changed

6 files changed

+403
-103
lines changed

Docs/sphinx_documentation/source/FFT.rst

+8
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,14 @@ in an :cpp:`FFT::Info` object passed to the constructor of
9393

9494
r2c.backward(cmf, mf);
9595

96+
.. _sec:FFT:c2c:
97+
98+
FFT::C2C Class
99+
==============
100+
101+
:cpp:`FFT::C2C` is a class template that supports complex-to-complex Fourier
102+
transforms. Its interface is similar to that of :cpp:`FFT::R2C`.
103+
96104
.. _sec:FFT:localr2c:
97105

98106
FFT::LocalR2C Class

Src/FFT/AMReX_FFT_Helper.H

+63-15
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ struct Info
6161

6262
//! For automatic strategy, this is the size per process below which we
6363
//! switch from slab to pencil.
64-
int pencil_threshold = 8;
64+
int pencil_threshold = 4;
6565

6666
//! Supported only in 3D. When twod_mode is true, FFT is performed on
6767
//! the first two dimensions only and the third dimension size is the
@@ -310,7 +310,7 @@ struct Plan
310310
void init_r2c (IntVectND<M> const& fft_size, void*, void*, bool cache, int ncomp = 1);
311311

312312
template <Direction D>
313-
void init_c2c (Box const& box, VendorComplex* p, int ncomp = 1)
313+
void init_c2c (Box const& box, VendorComplex* p, int ncomp = 1, int ndims = 1)
314314
{
315315
static_assert(D == Direction::forward || D == Direction::backward);
316316

@@ -319,9 +319,35 @@ struct Plan
319319
pf = (void*)p;
320320
pb = (void*)p;
321321

322-
n = box.length(0);
323-
howmany = AMREX_D_TERM(1, *box.length(1), *box.length(2));
324-
howmany *= ncomp;
322+
int len[3];
323+
324+
if (ndims == 1) {
325+
n = box.length(0);
326+
howmany = AMREX_D_TERM(1, *box.length(1), *box.length(2));
327+
howmany *= ncomp;
328+
len[0] = box.length(0);
329+
}
330+
#if (AMREX_SPACEDIM >= 2)
331+
else if (ndims == 2) {
332+
n = box.length(0) * box.length(1);
333+
#if (AMREX_SPACEDIM == 2)
334+
howmany = ncomp;
335+
#else
336+
howmany = box.length(2) * ncomp;
337+
#endif
338+
len[0] = box.length(1);
339+
len[1] = box.length(0);
340+
}
341+
#if (AMREX_SPACEDIM == 3)
342+
else if (ndims == 3) {
343+
n = box.length(0) * box.length(1) * box.length(2);
344+
howmany = ncomp;
345+
len[0] = box.length(2);
346+
len[1] = box.length(1);
347+
len[2] = box.length(0);
348+
}
349+
#endif
350+
#endif
325351

326352
#if defined(AMREX_USE_CUDA)
327353
AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
@@ -330,22 +356,39 @@ struct Plan
330356
cufftType t = std::is_same_v<float,T> ? CUFFT_C2C : CUFFT_Z2Z;
331357
std::size_t work_size;
332358
AMREX_CUFFT_SAFE_CALL
333-
(cufftMakePlanMany(plan, 1, &n, nullptr, 1, n, nullptr, 1, n, t, howmany, &work_size));
359+
(cufftMakePlanMany(plan, ndims, len, nullptr, 1, n, nullptr, 1, n, t, howmany, &work_size));
334360

335361
#elif defined(AMREX_USE_HIP)
336362

337363
auto prec = std::is_same_v<float,T> ? rocfft_precision_single
338364
: rocfft_precision_double;
339365
auto dir= (D == Direction::forward) ? rocfft_transform_type_complex_forward
340366
: rocfft_transform_type_complex_inverse;
341-
const std::size_t length = n;
367+
std::size_t length[3];
368+
if (ndims == 1) {
369+
length[0] = len[0];
370+
} else if (ndims == 2) {
371+
length[0] = len[1];
372+
length[1] = len[0];
373+
} else {
374+
length[0] = len[2];
375+
length[1] = len[1];
376+
length[2] = len[0];
377+
}
342378
AMREX_ROCFFT_SAFE_CALL
343-
(rocfft_plan_create(&plan, rocfft_placement_inplace, dir, prec, 1,
344-
&length, howmany, nullptr));
379+
(rocfft_plan_create(&plan, rocfft_placement_inplace, dir, prec, ndims,
380+
length, howmany, nullptr));
345381

346382
#elif defined(AMREX_USE_SYCL)
347383

348-
auto* pp = new mkl_desc_c(n);
384+
mkl_desc_c* pp;
385+
if (ndims == 1) {
386+
pp = new mkl_desc_c(n);
387+
} else if (ndims == 2) {
388+
pp = new mkl_desc_c({std::int64_t(len[0]), std::int64_t(len[1])});
389+
} else {
390+
pp = new mkl_desc_c({std::int64_t(len[0]), std::int64_t(len[1]), std::int64_t(len[2])});
391+
}
349392
#ifndef AMREX_USE_MKL_DFTI_2024
350393
pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
351394
oneapi::mkl::dft::config_value::INPLACE);
@@ -355,7 +398,12 @@ struct Plan
355398
pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
356399
pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n);
357400
pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, n);
358-
std::vector<std::int64_t> strides = {0,1};
401+
std::vector<std::int64_t> strides(ndims+1);
402+
strides[0] = 0;
403+
strides[ndims] = 1;
404+
for (int i = ndims-1; i >= 1; --i) {
405+
strides[i] = strides[i+1] * len[ndims-1-i];
406+
}
359407
#ifndef AMREX_USE_MKL_DFTI_2024
360408
pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
361409
pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
@@ -373,21 +421,21 @@ struct Plan
373421
if constexpr (std::is_same_v<float,T>) {
374422
if constexpr (D == Direction::forward) {
375423
plan = fftwf_plan_many_dft
376-
(1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, -1,
424+
(ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, -1,
377425
FFTW_ESTIMATE);
378426
} else {
379427
plan = fftwf_plan_many_dft
380-
(1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, +1,
428+
(ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, +1,
381429
FFTW_ESTIMATE);
382430
}
383431
} else {
384432
if constexpr (D == Direction::forward) {
385433
plan = fftw_plan_many_dft
386-
(1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, -1,
434+
(ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, -1,
387435
FFTW_ESTIMATE);
388436
} else {
389437
plan = fftw_plan_many_dft
390-
(1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, +1,
438+
(ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, +1,
391439
FFTW_ESTIMATE);
392440
}
393441
}

0 commit comments

Comments
 (0)