Skip to content

Commit

Permalink
Fix CUDA backend
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Jan 17, 2024
1 parent 90997e7 commit f867556
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 33 deletions.
10 changes: 5 additions & 5 deletions common/src/KokkosFFT_Cuda_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ struct FFTDataType {
cufftDoubleComplex, fftw_complex>;
};

// Maps an execution space (plus value types) to the plan handle type to use.
//   - `type` is cufftHandle when ExecutionSpace is Kokkos::Cuda, otherwise
//     `fftwHandle`.
//   - `fftwHandle` is fftwf_plan when KokkosFFT::Impl::real_type_t<...> of the
//     first value type is float, otherwise fftw_plan.
//
// NOTE(review): this is diff-view text. Each changed line appears twice,
// apparently the pre-commit form first and the post-commit form second. The
// commit replaces the single parameter `T` with `T1, T2`; in the lines shown
// here only `T1` is read, `T2` is unused.
template <typename ExecutionSpace, typename T>
template <typename ExecutionSpace, typename T1, typename T2>
struct FFTPlanType {
using fftwHandle =
std::conditional_t<std::is_same_v<KokkosFFT::Impl::real_type_t<T>, float>,
fftwf_plan, fftw_plan>;
using fftwHandle = std::conditional_t<
std::is_same_v<KokkosFFT::Impl::real_type_t<T1>, float>, fftwf_plan,
fftw_plan>;
using type = std::conditional_t<std::is_same_v<ExecutionSpace, Kokkos::Cuda>,
cufftHandle, fftwHandle>;
};
Expand Down Expand Up @@ -151,7 +151,7 @@ struct FFTDataType {
using complex128 = cufftDoubleComplex;
};

// Variant of FFTPlanType whose `type` is always cufftHandle; none of the
// template parameters are read in the body.
//
// NOTE(review): this is diff-view text. The two `template <...>` lines are
// apparently the pre-commit (`T`) and post-commit (`T1, T2`) forms of the same
// line. This definition sits in a different hunk from the cufftHandle/FFTW
// selecting variant, presumably under another preprocessor branch; confirm
// against the full header.
template <typename ExecutionSpace, typename T>
template <typename ExecutionSpace, typename T1, typename T2>
struct FFTPlanType {
using type = cufftHandle;
};
Expand Down
48 changes: 26 additions & 22 deletions fft/src/KokkosFFT_Cuda_plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ template <typename ExecutionSpace, typename PlanType, typename InViewType,
std::enable_if_t<InViewType::rank() == 1 &&
std::is_same_v<ExecutionSpace, Kokkos::Cuda>,
std::nullptr_t> = nullptr>
auto _create(const ExecutionSpace& exec_space, PlanType& plan,
auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] Direction direction) {
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -23,7 +23,8 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan,
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

cufftResult cufft_rt = cufftCreate(&plan);
plan = std::make_unique<PlanType>();
cufftResult cufft_rt = cufftCreate(&(*plan));
if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftCreate failed");

const int batch = 1;
Expand All @@ -37,7 +38,7 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan,
int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1,
std::multiplies<>());

cufft_rt = cufftPlan1d(&plan, nx, type, batch);
cufft_rt = cufftPlan1d(&(*plan), nx, type, batch);
if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftPlan1d failed");
return fft_size;
}
Expand All @@ -48,7 +49,7 @@ template <typename ExecutionSpace, typename PlanType, typename InViewType,
std::enable_if_t<InViewType::rank() == 2 &&
std::is_same_v<ExecutionSpace, Kokkos::Cuda>,
std::nullptr_t> = nullptr>
auto _create(const ExecutionSpace& exec_space, PlanType& plan,
auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] Direction direction) {
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -58,7 +59,8 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan,
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

cufftResult cufft_rt = cufftCreate(&plan);
plan = std::make_unique<PlanType>();
cufftResult cufft_rt = cufftCreate(&(*plan));
if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftCreate failed");

const int axis = 0;
Expand All @@ -70,7 +72,7 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan,
int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1,
std::multiplies<>());

cufft_rt = cufftPlan2d(&plan, nx, ny, type);
cufft_rt = cufftPlan2d(&(*plan), nx, ny, type);
if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftPlan2d failed");
return fft_size;
}
Expand All @@ -81,7 +83,7 @@ template <typename ExecutionSpace, typename PlanType, typename InViewType,
std::enable_if_t<InViewType::rank() == 3 &&
std::is_same_v<ExecutionSpace, Kokkos::Cuda>,
std::nullptr_t> = nullptr>
auto _create(const ExecutionSpace& exec_space, PlanType& plan,
auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] Direction direction) {
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -91,7 +93,8 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan,
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

cufftResult cufft_rt = cufftCreate(&plan);
plan = std::make_unique<PlanType>();
cufftResult cufft_rt = cufftCreate(&(*plan));
if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftCreate failed");

const int axis = 0;
Expand All @@ -106,7 +109,7 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan,
int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1,
std::multiplies<>());

cufft_rt = cufftPlan3d(&plan, nx, ny, nz, type);
cufft_rt = cufftPlan3d(&(*plan), nx, ny, nz, type);
if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftPlan3d failed");
return fft_size;
}
Expand All @@ -117,7 +120,7 @@ template <typename ExecutionSpace, typename PlanType, typename InViewType,
std::enable_if_t<std::isgreater(InViewType::rank(), 3) &&
std::is_same_v<ExecutionSpace, Kokkos::Cuda>,
std::nullptr_t> = nullptr>
auto _create(const ExecutionSpace& exec_space, PlanType& plan,
auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] Direction direction) {
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -127,7 +130,8 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan,
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

cufftResult cufft_rt = cufftCreate(&plan);
plan = std::make_unique<PlanType>();
cufftResult cufft_rt = cufftCreate(&(*plan));
if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftCreate failed");

const int rank = InViewType::rank();
Expand All @@ -144,8 +148,8 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan,
int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1,
std::multiplies<>());

cufft_rt = cufftPlanMany(&plan, rank, fft_extents.data(), nullptr, 1, idist,
nullptr, 1, odist, type, batch);
cufft_rt = cufftPlanMany(&(*plan), rank, fft_extents.data(), nullptr, 1,
idist, nullptr, 1, odist, type, batch);
if (cufft_rt != CUFFT_SUCCESS)
throw std::runtime_error("cufftPlanMany failed");
return fft_size;
Expand All @@ -156,7 +160,7 @@ template <typename ExecutionSpace, typename PlanType, typename InViewType,
typename OutViewType, std::size_t fft_rank = 1,
std::enable_if_t<std::is_same_v<ExecutionSpace, Kokkos::Cuda>,
std::nullptr_t> = nullptr>
auto _create(const ExecutionSpace& exec_space, PlanType& plan,
auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] Direction direction, axis_type<fft_rank> axes) {
static_assert(Kokkos::is_view<InViewType>::value,
Expand Down Expand Up @@ -185,24 +189,24 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan,
// For the moment, considering the contiguous layout only
int istride = 1, ostride = 1;

cufftResult cufft_rt = cufftCreate(&plan);
plan = std::make_unique<PlanType>();
cufftResult cufft_rt = cufftCreate(&(*plan));
if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftCreate failed");

cufft_rt =
cufftPlanMany(&plan, rank, fft_extents.data(), in_extents.data(), istride,
idist, out_extents.data(), ostride, odist, type, howmany);
cufft_rt = cufftPlanMany(&(*plan), rank, fft_extents.data(),
in_extents.data(), istride, idist,
out_extents.data(), ostride, odist, type, howmany);
if (cufft_rt != CUFFT_SUCCESS)
throw std::runtime_error("cufftPlanMany failed");

return fft_size;
}

// Destroys a cuFFT plan by passing the handle to cufftDestroy. The overload
// participates only when ExecutionSpace is Kokkos::Cuda (enable_if below).
//
// NOTE(review): this is diff-view text; changed lines appear twice,
// apparently pre-commit first and post-commit second.
//   - Pre-commit: `plan` is FFTPlanType<ExecutionSpace, T>::type& and is
//     passed to cufftDestroy directly.
//   - Post-commit: `plan` is std::unique_ptr<PlanType>& and is dereferenced
//     before the call.
// In the post-commit form ExecutionSpace does not appear in the parameter
// list, so it cannot be deduced from the argument.
//
// NOTE(review): the value returned by cufftDestroy is discarded, and the
// unique_ptr form dereferences `plan` without checking that it is non-null.
template <typename ExecutionSpace, typename T,
template <typename ExecutionSpace, typename PlanType,
std::enable_if_t<std::is_same_v<ExecutionSpace, Kokkos::Cuda>,
std::nullptr_t> = nullptr>
void _destroy(
typename KokkosFFT::Impl::FFTPlanType<ExecutionSpace, T>::type& plan) {
cufftDestroy(plan);
void _destroy(std::unique_ptr<PlanType>& plan) {
cufftDestroy(*plan);
}
} // namespace Impl
} // namespace KokkosFFT
Expand Down
12 changes: 6 additions & 6 deletions fft/src/KokkosFFT_Cuda_transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,42 +5,42 @@

namespace KokkosFFT {
namespace Impl {
// Executes `plan` through cufftExecR2C, reading cufftReal data from `idata`
// and writing cufftComplex data to `odata`.
// `direction` is accepted for signature parity with the other overloads but is
// not used here.
// Throws std::runtime_error("cufftExecR2C failed") when the call does not
// return CUFFT_SUCCESS.
//
// NOTE(review): diff-view text. The two signature lines are apparently the
// pre-commit (handle by value) and post-commit (handle by reference) forms.
void _exec(cufftHandle plan, cufftReal* idata, cufftComplex* odata,
void _exec(cufftHandle& plan, cufftReal* idata, cufftComplex* odata,
[[maybe_unused]] int direction) {
cufftResult cufft_rt = cufftExecR2C(plan, idata, odata);
if (cufft_rt != CUFFT_SUCCESS)
throw std::runtime_error("cufftExecR2C failed");
}

// Executes `plan` through cufftExecD2Z, reading cufftDoubleReal data from
// `idata` and writing cufftDoubleComplex data to `odata`.
// `direction` is not used by this overload.
// Throws std::runtime_error("cufftExecD2Z failed") when the call does not
// return CUFFT_SUCCESS.
//
// NOTE(review): diff-view text. The two signature lines are apparently the
// pre-commit (handle by value) and post-commit (handle by reference) forms.
void _exec(cufftHandle plan, cufftDoubleReal* idata, cufftDoubleComplex* odata,
void _exec(cufftHandle& plan, cufftDoubleReal* idata, cufftDoubleComplex* odata,
[[maybe_unused]] int direction) {
cufftResult cufft_rt = cufftExecD2Z(plan, idata, odata);
if (cufft_rt != CUFFT_SUCCESS)
throw std::runtime_error("cufftExecD2Z failed");
}

// Executes `plan` through cufftExecC2R, reading cufftComplex data from `idata`
// and writing cufftReal data to `odata`.
// `direction` is not used by this overload.
// Throws std::runtime_error("cufftExecC2R failed") when the call does not
// return CUFFT_SUCCESS.
//
// NOTE(review): diff-view text. The two signature lines are apparently the
// pre-commit (handle by value) and post-commit (handle by reference) forms.
void _exec(cufftHandle plan, cufftComplex* idata, cufftReal* odata,
void _exec(cufftHandle& plan, cufftComplex* idata, cufftReal* odata,
[[maybe_unused]] int direction) {
cufftResult cufft_rt = cufftExecC2R(plan, idata, odata);
if (cufft_rt != CUFFT_SUCCESS)
throw std::runtime_error("cufftExecC2R failed");
}

// Executes `plan` through cufftExecZ2D, reading cufftDoubleComplex data from
// `idata` and writing cufftDoubleReal data to `odata`.
// `direction` is not used by this overload.
// Throws std::runtime_error("cufftExecZ2D failed") when the call does not
// return CUFFT_SUCCESS.
//
// NOTE(review): diff-view text. The two signature lines are apparently the
// pre-commit (handle by value) and post-commit (handle by reference) forms.
void _exec(cufftHandle plan, cufftDoubleComplex* idata, cufftDoubleReal* odata,
void _exec(cufftHandle& plan, cufftDoubleComplex* idata, cufftDoubleReal* odata,
[[maybe_unused]] int direction) {
cufftResult cufft_rt = cufftExecZ2D(plan, idata, odata);
if (cufft_rt != CUFFT_SUCCESS)
throw std::runtime_error("cufftExecZ2D failed");
}

// Executes `plan` through cufftExecC2C with cufftComplex input and output.
// Unlike the real/complex overloads, `direction` is forwarded to cuFFT as the
// last argument of the call.
// Throws std::runtime_error("cufftExecC2C failed") when the call does not
// return CUFFT_SUCCESS.
//
// NOTE(review): diff-view text. The two signature lines are apparently the
// pre-commit (handle by value) and post-commit (handle by reference) forms.
void _exec(cufftHandle plan, cufftComplex* idata, cufftComplex* odata,
void _exec(cufftHandle& plan, cufftComplex* idata, cufftComplex* odata,
int direction) {
cufftResult cufft_rt = cufftExecC2C(plan, idata, odata, direction);
if (cufft_rt != CUFFT_SUCCESS)
throw std::runtime_error("cufftExecC2C failed");
}

void _exec(cufftHandle plan, cufftDoubleComplex* idata,
void _exec(cufftHandle& plan, cufftDoubleComplex* idata,
cufftDoubleComplex* odata, int direction) {
cufftResult cufft_rt = cufftExecZ2Z(plan, idata, odata, direction);
if (cufft_rt != CUFFT_SUCCESS)
Expand Down

0 comments on commit f867556

Please sign in to comment.