diff --git a/common/src/KokkosFFT_Cuda_types.hpp b/common/src/KokkosFFT_Cuda_types.hpp index 7234a791..c6e577e6 100644 --- a/common/src/KokkosFFT_Cuda_types.hpp +++ b/common/src/KokkosFFT_Cuda_types.hpp @@ -48,11 +48,11 @@ struct FFTDataType { cufftDoubleComplex, fftw_complex>; }; -template +template struct FFTPlanType { - using fftwHandle = - std::conditional_t, float>, - fftwf_plan, fftw_plan>; + using fftwHandle = std::conditional_t< + std::is_same_v, float>, fftwf_plan, + fftw_plan>; using type = std::conditional_t, cufftHandle, fftwHandle>; }; @@ -151,7 +151,7 @@ struct FFTDataType { using complex128 = cufftDoubleComplex; }; -template +template struct FFTPlanType { using type = cufftHandle; }; diff --git a/fft/src/KokkosFFT_Cuda_plans.hpp b/fft/src/KokkosFFT_Cuda_plans.hpp index e78dd43c..8d4bc7bd 100644 --- a/fft/src/KokkosFFT_Cuda_plans.hpp +++ b/fft/src/KokkosFFT_Cuda_plans.hpp @@ -13,7 +13,7 @@ template , std::nullptr_t> = nullptr> -auto _create(const ExecutionSpace& exec_space, PlanType& plan, +auto _create(const ExecutionSpace& exec_space, std::unique_ptr& plan, const InViewType& in, const OutViewType& out, [[maybe_unused]] Direction direction) { static_assert(Kokkos::is_view::value, @@ -23,7 +23,8 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan, using in_value_type = typename InViewType::non_const_value_type; using out_value_type = typename OutViewType::non_const_value_type; - cufftResult cufft_rt = cufftCreate(&plan); + plan = std::make_unique(); + cufftResult cufft_rt = cufftCreate(&(*plan)); if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftCreate failed"); const int batch = 1; @@ -37,7 +38,7 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan, int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1, std::multiplies<>()); - cufft_rt = cufftPlan1d(&plan, nx, type, batch); + cufft_rt = cufftPlan1d(&(*plan), nx, type, batch); if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftPlan1d failed"); return fft_size; } @@ -48,7 +49,7 @@ template , std::nullptr_t> = nullptr> -auto _create(const ExecutionSpace& exec_space, PlanType& plan, +auto _create(const ExecutionSpace& exec_space, std::unique_ptr& plan, const InViewType& in, const OutViewType& out, [[maybe_unused]] Direction direction) { static_assert(Kokkos::is_view::value, @@ -58,7 +59,8 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan, using in_value_type = typename InViewType::non_const_value_type; using out_value_type = typename OutViewType::non_const_value_type; - cufftResult cufft_rt = cufftCreate(&plan); + plan = std::make_unique(); + cufftResult cufft_rt = cufftCreate(&(*plan)); if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftCreate failed"); const int axis = 0; @@ -70,7 +72,7 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan, int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1, std::multiplies<>()); - cufft_rt = cufftPlan2d(&plan, nx, ny, type); + cufft_rt = cufftPlan2d(&(*plan), nx, ny, type); if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftPlan2d failed"); return fft_size; } @@ -81,7 +83,7 @@ template , std::nullptr_t> = nullptr> -auto _create(const ExecutionSpace& exec_space, PlanType& plan, +auto _create(const ExecutionSpace& exec_space, std::unique_ptr& plan, const InViewType& in, const OutViewType& out, [[maybe_unused]] Direction direction) { static_assert(Kokkos::is_view::value, @@ -91,7 +93,8 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan, using in_value_type = typename InViewType::non_const_value_type; using out_value_type = typename OutViewType::non_const_value_type; - cufftResult cufft_rt = cufftCreate(&plan); + plan = std::make_unique(); + cufftResult cufft_rt = cufftCreate(&(*plan)); if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftCreate failed"); const int axis = 0; @@ -106,7 +109,7 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan, int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1, std::multiplies<>()); - cufft_rt = cufftPlan3d(&plan, nx, ny, nz, type); + cufft_rt = cufftPlan3d(&(*plan), nx, ny, nz, type); if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftPlan3d failed"); return fft_size; } @@ -117,7 +120,7 @@ template , std::nullptr_t> = nullptr> -auto _create(const ExecutionSpace& exec_space, PlanType& plan, +auto _create(const ExecutionSpace& exec_space, std::unique_ptr& plan, const InViewType& in, const OutViewType& out, [[maybe_unused]] Direction direction) { static_assert(Kokkos::is_view::value, @@ -127,7 +130,8 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan, using in_value_type = typename InViewType::non_const_value_type; using out_value_type = typename OutViewType::non_const_value_type; - cufftResult cufft_rt = cufftCreate(&plan); + plan = std::make_unique(); + cufftResult cufft_rt = cufftCreate(&(*plan)); if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftCreate failed"); const int rank = InViewType::rank(); @@ -144,8 +148,8 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan, int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1, std::multiplies<>()); - cufft_rt = cufftPlanMany(&plan, rank, fft_extents.data(), nullptr, 1, idist, - nullptr, 1, odist, type, batch); + cufft_rt = cufftPlanMany(&(*plan), rank, fft_extents.data(), nullptr, 1, + idist, nullptr, 1, odist, type, batch); if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftPlanMany failed"); return fft_size; @@ -156,7 +160,7 @@ template , std::nullptr_t> = nullptr> -auto _create(const ExecutionSpace& exec_space, PlanType& plan, +auto _create(const ExecutionSpace& exec_space, std::unique_ptr& plan, const InViewType& in, const OutViewType& out, [[maybe_unused]] Direction direction, axis_type axes) { static_assert(Kokkos::is_view::value, @@ -185,24 +189,24 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan, // For the moment, considering the contiguous layout only int istride = 1, ostride = 1; - cufftResult cufft_rt = cufftCreate(&plan); + plan = std::make_unique(); + cufftResult cufft_rt = cufftCreate(&(*plan)); if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftCreate failed"); - cufft_rt = - cufftPlanMany(&plan, rank, fft_extents.data(), in_extents.data(), istride, - idist, out_extents.data(), ostride, odist, type, howmany); + cufft_rt = cufftPlanMany(&(*plan), rank, fft_extents.data(), + in_extents.data(), istride, idist, + out_extents.data(), ostride, odist, type, howmany); if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftPlanMany failed"); return fft_size; } -template , std::nullptr_t> = nullptr> -void _destroy( - typename KokkosFFT::Impl::FFTPlanType::type& plan) { - cufftDestroy(plan); +void _destroy(std::unique_ptr& plan) { + cufftDestroy(*plan); } } // namespace Impl } // namespace KokkosFFT diff --git a/fft/src/KokkosFFT_Cuda_transform.hpp b/fft/src/KokkosFFT_Cuda_transform.hpp index 478e16a0..4ba45964 100644 --- a/fft/src/KokkosFFT_Cuda_transform.hpp +++ b/fft/src/KokkosFFT_Cuda_transform.hpp @@ -5,42 +5,42 @@ namespace KokkosFFT { namespace Impl { -void _exec(cufftHandle plan, cufftReal* idata, cufftComplex* odata, +void _exec(cufftHandle& plan, cufftReal* idata, cufftComplex* odata, [[maybe_unused]] int direction) { cufftResult cufft_rt = cufftExecR2C(plan, idata, odata); if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftExecR2C failed"); } -void _exec(cufftHandle plan, cufftDoubleReal* idata, cufftDoubleComplex* odata, +void _exec(cufftHandle& plan, cufftDoubleReal* idata, cufftDoubleComplex* odata, [[maybe_unused]] int direction) { cufftResult cufft_rt = cufftExecD2Z(plan, idata, odata); if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftExecD2Z failed"); } -void _exec(cufftHandle plan, cufftComplex* idata, cufftReal* odata, +void _exec(cufftHandle& plan, cufftComplex* idata, cufftReal* odata, [[maybe_unused]] int direction) { cufftResult cufft_rt = cufftExecC2R(plan, idata, odata); if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftExecC2R failed"); } -void _exec(cufftHandle plan, cufftDoubleComplex* idata, cufftDoubleReal* odata, +void _exec(cufftHandle& plan, cufftDoubleComplex* idata, cufftDoubleReal* odata, [[maybe_unused]] int direction) { cufftResult cufft_rt = cufftExecZ2D(plan, idata, odata); if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftExecZ2D failed"); } -void _exec(cufftHandle plan, cufftComplex* idata, cufftComplex* odata, +void _exec(cufftHandle& plan, cufftComplex* idata, cufftComplex* odata, int direction) { cufftResult cufft_rt = cufftExecC2C(plan, idata, odata, direction); if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftExecC2C failed"); } -void _exec(cufftHandle plan, cufftDoubleComplex* idata, +void _exec(cufftHandle& plan, cufftDoubleComplex* idata, cufftDoubleComplex* odata, int direction) { cufftResult cufft_rt = cufftExecZ2Z(plan, idata, odata, direction); if (cufft_rt != CUFFT_SUCCESS)