diff --git a/common/src/KokkosFFT_HIP_types.hpp b/common/src/KokkosFFT_HIP_types.hpp index 8729f1e7..fe1fa574 100644 --- a/common/src/KokkosFFT_HIP_types.hpp +++ b/common/src/KokkosFFT_HIP_types.hpp @@ -48,11 +48,11 @@ struct FFTDataType { hipfftDoubleComplex, fftw_complex>; }; -template +template struct FFTPlanType { - using fftwHandle = - std::conditional_t, float>, - fftwf_plan, fftw_plan>; + using fftwHandle = std::conditional_t< + std::is_same_v, float>, fftwf_plan, + fftw_plan>; using type = std::conditional_t, hipfftHandle, fftwHandle>; }; @@ -151,7 +151,7 @@ struct FFTDataType { using complex128 = hipfftDoubleComplex; }; -template +template struct FFTPlanType { using type = hipfftHandle; }; diff --git a/fft/src/KokkosFFT_HIP_plans.hpp b/fft/src/KokkosFFT_HIP_plans.hpp index 7d494d32..fc3d67e6 100644 --- a/fft/src/KokkosFFT_HIP_plans.hpp +++ b/fft/src/KokkosFFT_HIP_plans.hpp @@ -13,7 +13,7 @@ template , std::nullptr_t> = nullptr> -auto _create(const ExecutionSpace& exec_space, PlanType& plan, +auto _create(const ExecutionSpace& exec_space, std::unique_ptr& plan, const InViewType& in, const OutViewType& out, [[maybe_unused]] FFTDirectionType direction) { static_assert(Kokkos::is_view::value, @@ -23,7 +23,8 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan, using in_value_type = typename InViewType::non_const_value_type; using out_value_type = typename OutViewType::non_const_value_type; - hipfftResult hipfft_rt = hipfftCreate(&plan); + plan = std::make_unique(); + hipfftResult hipfft_rt = hipfftCreate(&(*plan)); if (hipfft_rt != HIPFFT_SUCCESS) throw std::runtime_error("hipfftCreate failed"); @@ -38,7 +39,7 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan, int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1, std::multiplies<>()); - hipfft_rt = hipfftPlan1d(&plan, nx, type, batch); + hipfft_rt = hipfftPlan1d(&(*plan), nx, type, batch); if (hipfft_rt != HIPFFT_SUCCESS) throw std::runtime_error("hipfftPlan1d failed"); return fft_size; @@ -50,7 +51,7 @@ template , std::nullptr_t> = nullptr> -auto _create(const ExecutionSpace& exec_space, PlanType& plan, +auto _create(const ExecutionSpace& exec_space, std::unique_ptr& plan, const InViewType& in, const OutViewType& out, [[maybe_unused]] FFTDirectionType direction) { static_assert(Kokkos::is_view::value, @@ -60,7 +61,8 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan, using in_value_type = typename InViewType::non_const_value_type; using out_value_type = typename OutViewType::non_const_value_type; - hipfftResult hipfft_rt = hipfftCreate(&plan); + plan = std::make_unique(); + hipfftResult hipfft_rt = hipfftCreate(&(*plan)); if (hipfft_rt != HIPFFT_SUCCESS) throw std::runtime_error("hipfftCreate failed"); @@ -73,7 +75,7 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan, int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1, std::multiplies<>()); - hipfft_rt = hipfftPlan2d(&plan, nx, ny, type); + hipfft_rt = hipfftPlan2d(&(*plan), nx, ny, type); if (hipfft_rt != HIPFFT_SUCCESS) throw std::runtime_error("hipfftPlan2d failed"); return fft_size; @@ -85,7 +87,7 @@ template , std::nullptr_t> = nullptr> -auto _create(const ExecutionSpace& exec_space, PlanType& plan, +auto _create(const ExecutionSpace& exec_space, std::unique_ptr& plan, const InViewType& in, const OutViewType& out, [[maybe_unused]] FFTDirectionType direction) { static_assert(Kokkos::is_view::value, @@ -95,7 +97,8 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan, using in_value_type = typename InViewType::non_const_value_type; using out_value_type = typename OutViewType::non_const_value_type; - hipfftResult hipfft_rt = hipfftCreate(&plan); + plan = std::make_unique(); + hipfftResult hipfft_rt = hipfftCreate(&(*plan)); if (hipfft_rt != HIPFFT_SUCCESS) throw std::runtime_error("hipfftCreate failed"); @@ -112,7 +115,7 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan, int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1, std::multiplies<>()); - hipfft_rt = hipfftPlan3d(&plan, nx, ny, nz, type); + hipfft_rt = hipfftPlan3d(&(*plan), nx, ny, nz, type); if (hipfft_rt != HIPFFT_SUCCESS) throw std::runtime_error("hipfftPlan3d failed"); return fft_size; @@ -124,7 +127,7 @@ template , std::nullptr_t> = nullptr> -auto _create(const ExecutionSpace& exec_space, PlanType& plan, +auto _create(const ExecutionSpace& exec_space, std::unique_ptr& plan, const InViewType& in, const OutViewType& out, [[maybe_unused]] FFTDirectionType direction) { static_assert(Kokkos::is_view::value, @@ -134,7 +137,8 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan, using in_value_type = typename InViewType::non_const_value_type; using out_value_type = typename OutViewType::non_const_value_type; - hipfftResult hipfft_rt = hipfftCreate(&plan); + plan = std::make_unique(); + hipfftResult hipfft_rt = hipfftCreate(&(*plan)); if (hipfft_rt != HIPFFT_SUCCESS) throw std::runtime_error("hipfftCreate failed"); @@ -152,8 +156,8 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan, int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1, std::multiplies<>()); - hipfft_rt = hipfftPlanMany(&plan, rank, fft_extents.data(), nullptr, 1, idist, - nullptr, 1, odist, type, batch); + hipfft_rt = hipfftPlanMany(&(*plan), rank, fft_extents.data(), nullptr, 1, + idist, nullptr, 1, odist, type, batch); if (hipfft_rt != HIPFFT_SUCCESS) throw std::runtime_error("hipfftPlanMany failed"); return fft_size; @@ -165,7 +169,7 @@ template , std::nullptr_t> = nullptr> -auto _create(const ExecutionSpace& exec_space, PlanType& plan, +auto _create(const ExecutionSpace& exec_space, std::unique_ptr& plan, const InViewType& in, const OutViewType& out, [[maybe_unused]] FFTDirectionType direction, axis_type axes) { @@ -195,25 +199,25 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan, // For the moment, considering the contiguous layout only int istride = 1, ostride = 1; - hipfftResult hipfft_rt = hipfftCreate(&plan); + plan = std::make_unique(); + hipfftResult hipfft_rt = hipfftCreate(&(*plan)); if (hipfft_rt != HIPFFT_SUCCESS) throw std::runtime_error("hipfftCreate failed"); - hipfft_rt = hipfftPlanMany(&plan, rank, fft_extents.data(), in_extents.data(), - istride, idist, out_extents.data(), ostride, odist, - type, howmany); + hipfft_rt = hipfftPlanMany(&(*plan), rank, fft_extents.data(), + in_extents.data(), istride, idist, + out_extents.data(), ostride, odist, type, howmany); if (hipfft_rt != HIPFFT_SUCCESS) throw std::runtime_error("hipfftPlan failed"); return fft_size; } -template , std::nullptr_t> = nullptr> -void _destroy( - typename KokkosFFT::Impl::FFTPlanType::type& plan) { - hipfftDestroy(plan); +void _destroy(std::unique_ptr& plan) { + hipfftDestroy(*plan); } } // namespace Impl } // namespace KokkosFFT diff --git a/fft/src/KokkosFFT_HIP_transform.hpp b/fft/src/KokkosFFT_HIP_transform.hpp index dd63fbbb..30c1785a 100644 --- a/fft/src/KokkosFFT_HIP_transform.hpp +++ b/fft/src/KokkosFFT_HIP_transform.hpp @@ -5,42 +5,42 @@ namespace KokkosFFT { namespace Impl { -void _exec(hipfftHandle plan, hipfftReal* idata, hipfftComplex* odata, +void _exec(hipfftHandle& plan, hipfftReal* idata, hipfftComplex* odata, [[maybe_unused]] int direction) { hipfftResult hipfft_rt = hipfftExecR2C(plan, idata, odata); if (hipfft_rt != HIPFFT_SUCCESS) throw std::runtime_error("hipfftExecR2C failed"); } -void _exec(hipfftHandle plan, hipfftDoubleReal* idata, +void _exec(hipfftHandle& plan, hipfftDoubleReal* idata, hipfftDoubleComplex* odata, [[maybe_unused]] int direction) { hipfftResult hipfft_rt = hipfftExecD2Z(plan, idata, odata); if (hipfft_rt != HIPFFT_SUCCESS) throw std::runtime_error("hipfftExecD2Z failed"); } -void _exec(hipfftHandle plan, hipfftComplex* idata, hipfftReal* odata, +void _exec(hipfftHandle& plan, hipfftComplex* idata, hipfftReal* odata, [[maybe_unused]] int direction) { hipfftResult hipfft_rt = hipfftExecC2R(plan, idata, odata); if (hipfft_rt != HIPFFT_SUCCESS) throw std::runtime_error("hipfftExecC2R failed"); } -void _exec(hipfftHandle plan, hipfftDoubleComplex* idata, +void _exec(hipfftHandle& plan, hipfftDoubleComplex* idata, hipfftDoubleReal* odata, [[maybe_unused]] int direction) { hipfftResult hipfft_rt = hipfftExecZ2D(plan, idata, odata); if (hipfft_rt != HIPFFT_SUCCESS) throw std::runtime_error("hipfftExecZ2D failed"); } -void _exec(hipfftHandle plan, hipfftComplex* idata, hipfftComplex* odata, +void _exec(hipfftHandle& plan, hipfftComplex* idata, hipfftComplex* odata, int direction) { hipfftResult hipfft_rt = hipfftExecC2C(plan, idata, odata, direction); if (hipfft_rt != HIPFFT_SUCCESS) throw std::runtime_error("hipfftExecC2C failed"); } -void _exec(hipfftHandle plan, hipfftDoubleComplex* idata, +void _exec(hipfftHandle& plan, hipfftDoubleComplex* idata, hipfftDoubleComplex* odata, int direction) { hipfftResult hipfft_rt = hipfftExecZ2Z(plan, idata, odata, direction); if (hipfft_rt != HIPFFT_SUCCESS)