@@ -61,7 +61,7 @@ struct Info
61
61
62
62
//! For automatic strategy, this is the size per process below which we
63
63
//! switch from slab to pencil.
64
- int pencil_threshold = 8 ;
64
+ int pencil_threshold = 4 ;
65
65
66
66
//! Supported only in 3D. When twod_mode is true, FFT is performed on
67
67
//! the first two dimensions only and the third dimension size is the
@@ -310,7 +310,7 @@ struct Plan
310
310
void init_r2c (IntVectND<M> const & fft_size, void *, void *, bool cache, int ncomp = 1 );
311
311
312
312
template <Direction D>
313
- void init_c2c (Box const & box, VendorComplex* p, int ncomp = 1 )
313
+ void init_c2c (Box const & box, VendorComplex* p, int ncomp = 1 , int ndims = 1 )
314
314
{
315
315
static_assert (D == Direction::forward || D == Direction::backward);
316
316
@@ -319,9 +319,35 @@ struct Plan
319
319
pf = (void *)p;
320
320
pb = (void *)p;
321
321
322
- n = box.length (0 );
323
- howmany = AMREX_D_TERM (1 , *box.length (1 ), *box.length (2 ));
324
- howmany *= ncomp;
322
+ int len[3 ];
323
+
324
+ if (ndims == 1 ) {
325
+ n = box.length (0 );
326
+ howmany = AMREX_D_TERM (1 , *box.length (1 ), *box.length (2 ));
327
+ howmany *= ncomp;
328
+ len[0 ] = box.length (0 );
329
+ }
330
+ #if (AMREX_SPACEDIM >= 2)
331
+ else if (ndims == 2 ) {
332
+ n = box.length (0 ) * box.length (1 );
333
+ #if (AMREX_SPACEDIM == 2)
334
+ howmany = ncomp;
335
+ #else
336
+ howmany = box.length (2 ) * ncomp;
337
+ #endif
338
+ len[0 ] = box.length (1 );
339
+ len[1 ] = box.length (0 );
340
+ }
341
+ #if (AMREX_SPACEDIM == 3)
342
+ else if (ndims == 3 ) {
343
+ n = box.length (0 ) * box.length (1 ) * box.length (2 );
344
+ howmany = ncomp;
345
+ len[0 ] = box.length (2 );
346
+ len[1 ] = box.length (1 );
347
+ len[2 ] = box.length (0 );
348
+ }
349
+ #endif
350
+ #endif
325
351
326
352
#if defined(AMREX_USE_CUDA)
327
353
AMREX_CUFFT_SAFE_CALL (cufftCreate (&plan));
@@ -330,22 +356,39 @@ struct Plan
330
356
cufftType t = std::is_same_v<float ,T> ? CUFFT_C2C : CUFFT_Z2Z;
331
357
std::size_t work_size;
332
358
AMREX_CUFFT_SAFE_CALL
333
- (cufftMakePlanMany (plan, 1 , &n , nullptr , 1 , n, nullptr , 1 , n, t, howmany, &work_size));
359
+ (cufftMakePlanMany (plan, ndims, len , nullptr , 1 , n, nullptr , 1 , n, t, howmany, &work_size));
334
360
335
361
#elif defined(AMREX_USE_HIP)
336
362
337
363
auto prec = std::is_same_v<float ,T> ? rocfft_precision_single
338
364
: rocfft_precision_double;
339
365
auto dir= (D == Direction::forward) ? rocfft_transform_type_complex_forward
340
366
: rocfft_transform_type_complex_inverse;
341
- const std::size_t length = n;
367
+ std::size_t length[3 ];
368
+ if (ndims == 1 ) {
369
+ length[0 ] = len[0 ];
370
+ } else if (ndims == 2 ) {
371
+ length[0 ] = len[1 ];
372
+ length[1 ] = len[0 ];
373
+ } else {
374
+ length[0 ] = len[2 ];
375
+ length[1 ] = len[1 ];
376
+ length[2 ] = len[0 ];
377
+ }
342
378
AMREX_ROCFFT_SAFE_CALL
343
- (rocfft_plan_create (&plan, rocfft_placement_inplace, dir, prec, 1 ,
344
- & length, howmany, nullptr ));
379
+ (rocfft_plan_create (&plan, rocfft_placement_inplace, dir, prec, ndims ,
380
+ length, howmany, nullptr ));
345
381
346
382
#elif defined(AMREX_USE_SYCL)
347
383
348
- auto * pp = new mkl_desc_c (n);
384
+ mkl_desc_c* pp;
385
+ if (ndims == 1 ) {
386
+ pp = new mkl_desc_c (n);
387
+ } else if (ndims == 2 ) {
388
+ pp = new mkl_desc_c ({std::int64_t (len[0 ]), std::int64_t (len[1 ])});
389
+ } else {
390
+ pp = new mkl_desc_c ({std::int64_t (len[0 ]), std::int64_t (len[1 ]), std::int64_t (len[2 ])});
391
+ }
349
392
#ifndef AMREX_USE_MKL_DFTI_2024
350
393
pp->set_value (oneapi::mkl::dft::config_param::PLACEMENT,
351
394
oneapi::mkl::dft::config_value::INPLACE);
@@ -355,7 +398,12 @@ struct Plan
355
398
pp->set_value (oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
356
399
pp->set_value (oneapi::mkl::dft::config_param::FWD_DISTANCE, n);
357
400
pp->set_value (oneapi::mkl::dft::config_param::BWD_DISTANCE, n);
358
- std::vector<std::int64_t > strides = {0 ,1 };
401
+ std::vector<std::int64_t > strides (ndims+1 );
402
+ strides[0 ] = 0 ;
403
+ strides[ndims] = 1 ;
404
+ for (int i = ndims-1 ; i >= 1 ; --i) {
405
+ strides[i] = strides[i+1 ] * len[ndims-1 -i];
406
+ }
359
407
#ifndef AMREX_USE_MKL_DFTI_2024
360
408
pp->set_value (oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
361
409
pp->set_value (oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
@@ -373,21 +421,21 @@ struct Plan
373
421
if constexpr (std::is_same_v<float ,T>) {
374
422
if constexpr (D == Direction::forward) {
375
423
plan = fftwf_plan_many_dft
376
- (1 , &n , howmany, p, nullptr , 1 , n, p, nullptr , 1 , n, -1 ,
424
+ (ndims, len , howmany, p, nullptr , 1 , n, p, nullptr , 1 , n, -1 ,
377
425
FFTW_ESTIMATE);
378
426
} else {
379
427
plan = fftwf_plan_many_dft
380
- (1 , &n , howmany, p, nullptr , 1 , n, p, nullptr , 1 , n, +1 ,
428
+ (ndims, len , howmany, p, nullptr , 1 , n, p, nullptr , 1 , n, +1 ,
381
429
FFTW_ESTIMATE);
382
430
}
383
431
} else {
384
432
if constexpr (D == Direction::forward) {
385
433
plan = fftw_plan_many_dft
386
- (1 , &n , howmany, p, nullptr , 1 , n, p, nullptr , 1 , n, -1 ,
434
+ (ndims, len , howmany, p, nullptr , 1 , n, p, nullptr , 1 , n, -1 ,
387
435
FFTW_ESTIMATE);
388
436
} else {
389
437
plan = fftw_plan_many_dft
390
- (1 , &n , howmany, p, nullptr , 1 , n, p, nullptr , 1 , n, +1 ,
438
+ (ndims, len , howmany, p, nullptr , 1 , n, p, nullptr , 1 , n, +1 ,
391
439
FFTW_ESTIMATE);
392
440
}
393
441
}
0 commit comments