Skip to content

Commit

Permalink
Fix HIP backend
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Jan 17, 2024
1 parent f867556 commit 92c0ad3
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 33 deletions.
10 changes: 5 additions & 5 deletions common/src/KokkosFFT_HIP_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ struct FFTDataType {
hipfftDoubleComplex, fftw_complex>;
};

template <typename ExecutionSpace, typename T>
template <typename ExecutionSpace, typename T1, typename T2>
struct FFTPlanType {
using fftwHandle =
std::conditional_t<std::is_same_v<KokkosFFT::Impl::real_type_t<T>, float>,
fftwf_plan, fftw_plan>;
using fftwHandle = std::conditional_t<
std::is_same_v<KokkosFFT::Impl::real_type_t<T1>, float>, fftwf_plan,
fftw_plan>;
using type = std::conditional_t<std::is_same_v<ExecutionSpace, Kokkos::HIP>,
hipfftHandle, fftwHandle>;
};
Expand Down Expand Up @@ -151,7 +151,7 @@ struct FFTDataType {
using complex128 = hipfftDoubleComplex;
};

template <typename ExecutionSpace, typename T>
template <typename ExecutionSpace, typename T1, typename T2>
struct FFTPlanType {
using type = hipfftHandle;
};
Expand Down
48 changes: 26 additions & 22 deletions fft/src/KokkosFFT_HIP_plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ template <typename ExecutionSpace, typename PlanType, typename InViewType,
std::enable_if_t<InViewType::rank() == 1 &&
std::is_same_v<ExecutionSpace, Kokkos::HIP>,
std::nullptr_t> = nullptr>
auto _create(const ExecutionSpace& exec_space, PlanType& plan,
auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] FFTDirectionType direction) {
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -23,7 +23,8 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan,
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

hipfftResult hipfft_rt = hipfftCreate(&plan);
plan = std::make_unique<PlanType>();
hipfftResult hipfft_rt = hipfftCreate(&(*plan));
if (hipfft_rt != HIPFFT_SUCCESS)
throw std::runtime_error("hipfftCreate failed");

Expand All @@ -38,7 +39,7 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan,
int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1,
std::multiplies<>());

hipfft_rt = hipfftPlan1d(&plan, nx, type, batch);
hipfft_rt = hipfftPlan1d(&(*plan), nx, type, batch);
if (hipfft_rt != HIPFFT_SUCCESS)
throw std::runtime_error("hipfftPlan1d failed");
return fft_size;
Expand All @@ -50,7 +51,7 @@ template <typename ExecutionSpace, typename PlanType, typename InViewType,
std::enable_if_t<InViewType::rank() == 2 &&
std::is_same_v<ExecutionSpace, Kokkos::HIP>,
std::nullptr_t> = nullptr>
auto _create(const ExecutionSpace& exec_space, PlanType& plan,
auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] FFTDirectionType direction) {
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -60,7 +61,8 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan,
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

hipfftResult hipfft_rt = hipfftCreate(&plan);
plan = std::make_unique<PlanType>();
hipfftResult hipfft_rt = hipfftCreate(&(*plan));
if (hipfft_rt != HIPFFT_SUCCESS)
throw std::runtime_error("hipfftCreate failed");

Expand All @@ -73,7 +75,7 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan,
int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1,
std::multiplies<>());

hipfft_rt = hipfftPlan2d(&plan, nx, ny, type);
hipfft_rt = hipfftPlan2d(&(*plan), nx, ny, type);
if (hipfft_rt != HIPFFT_SUCCESS)
throw std::runtime_error("hipfftPlan2d failed");
return fft_size;
Expand All @@ -85,7 +87,7 @@ template <typename ExecutionSpace, typename PlanType, typename InViewType,
std::enable_if_t<InViewType::rank() == 3 &&
std::is_same_v<ExecutionSpace, Kokkos::HIP>,
std::nullptr_t> = nullptr>
auto _create(const ExecutionSpace& exec_space, PlanType& plan,
auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] FFTDirectionType direction) {
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -95,7 +97,8 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan,
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

hipfftResult hipfft_rt = hipfftCreate(&plan);
plan = std::make_unique<PlanType>();
hipfftResult hipfft_rt = hipfftCreate(&(*plan));
if (hipfft_rt != HIPFFT_SUCCESS)
throw std::runtime_error("hipfftCreate failed");

Expand All @@ -112,7 +115,7 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan,
int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1,
std::multiplies<>());

hipfft_rt = hipfftPlan3d(&plan, nx, ny, nz, type);
hipfft_rt = hipfftPlan3d(&(*plan), nx, ny, nz, type);
if (hipfft_rt != HIPFFT_SUCCESS)
throw std::runtime_error("hipfftPlan3d failed");
return fft_size;
Expand All @@ -124,7 +127,7 @@ template <typename ExecutionSpace, typename PlanType, typename InViewType,
std::enable_if_t<std::isgreater(InViewType::rank(), 3) &&
std::is_same_v<ExecutionSpace, Kokkos::HIP>,
std::nullptr_t> = nullptr>
auto _create(const ExecutionSpace& exec_space, PlanType& plan,
auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] FFTDirectionType direction) {
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -134,7 +137,8 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan,
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

hipfftResult hipfft_rt = hipfftCreate(&plan);
plan = std::make_unique<PlanType>();
hipfftResult hipfft_rt = hipfftCreate(&(*plan));
if (hipfft_rt != HIPFFT_SUCCESS)
throw std::runtime_error("hipfftCreate failed");

Expand All @@ -152,8 +156,8 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan,
int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1,
std::multiplies<>());

hipfft_rt = hipfftPlanMany(&plan, rank, fft_extents.data(), nullptr, 1, idist,
nullptr, 1, odist, type, batch);
hipfft_rt = hipfftPlanMany(&(*plan), rank, fft_extents.data(), nullptr, 1,
idist, nullptr, 1, odist, type, batch);
if (hipfft_rt != HIPFFT_SUCCESS)
throw std::runtime_error("hipfftPlanMany failed");
return fft_size;
Expand All @@ -165,7 +169,7 @@ template <typename ExecutionSpace, typename PlanType, typename InViewType,
std::size_t fft_rank = 1,
std::enable_if_t<std::is_same_v<ExecutionSpace, Kokkos::HIP>,
std::nullptr_t> = nullptr>
auto _create(const ExecutionSpace& exec_space, PlanType& plan,
auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] FFTDirectionType direction,
axis_type<fft_rank> axes) {
Expand Down Expand Up @@ -195,25 +199,25 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan,
// For the moment, considering the contiguous layout only
int istride = 1, ostride = 1;

hipfftResult hipfft_rt = hipfftCreate(&plan);
plan = std::make_unique<PlanType>();
hipfftResult hipfft_rt = hipfftCreate(&(*plan));
if (hipfft_rt != HIPFFT_SUCCESS)
throw std::runtime_error("hipfftCreate failed");

hipfft_rt = hipfftPlanMany(&plan, rank, fft_extents.data(), in_extents.data(),
istride, idist, out_extents.data(), ostride, odist,
type, howmany);
hipfft_rt = hipfftPlanMany(&(*plan), rank, fft_extents.data(),
in_extents.data(), istride, idist,
out_extents.data(), ostride, odist, type, howmany);

if (hipfft_rt != HIPFFT_SUCCESS)
throw std::runtime_error("hipfftPlan failed");
return fft_size;
}

template <typename ExecutionSpace, typename T,
template <typename ExecutionSpace, typename PlanType,
std::enable_if_t<std::is_same_v<ExecutionSpace, Kokkos::HIP>,
std::nullptr_t> = nullptr>
void _destroy(
typename KokkosFFT::Impl::FFTPlanType<ExecutionSpace, T>::type& plan) {
hipfftDestroy(plan);
void _destroy(std::unique_ptr<PlanType>& plan) {
hipfftDestroy(*plan);
}
} // namespace Impl
} // namespace KokkosFFT
Expand Down
12 changes: 6 additions & 6 deletions fft/src/KokkosFFT_HIP_transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,42 +5,42 @@

namespace KokkosFFT {
namespace Impl {
void _exec(hipfftHandle plan, hipfftReal* idata, hipfftComplex* odata,
void _exec(hipfftHandle& plan, hipfftReal* idata, hipfftComplex* odata,
[[maybe_unused]] int direction) {
hipfftResult hipfft_rt = hipfftExecR2C(plan, idata, odata);
if (hipfft_rt != HIPFFT_SUCCESS)
throw std::runtime_error("hipfftExecR2C failed");
}

void _exec(hipfftHandle plan, hipfftDoubleReal* idata,
void _exec(hipfftHandle& plan, hipfftDoubleReal* idata,
hipfftDoubleComplex* odata, [[maybe_unused]] int direction) {
hipfftResult hipfft_rt = hipfftExecD2Z(plan, idata, odata);
if (hipfft_rt != HIPFFT_SUCCESS)
throw std::runtime_error("hipfftExecD2Z failed");
}

void _exec(hipfftHandle plan, hipfftComplex* idata, hipfftReal* odata,
void _exec(hipfftHandle& plan, hipfftComplex* idata, hipfftReal* odata,
[[maybe_unused]] int direction) {
hipfftResult hipfft_rt = hipfftExecC2R(plan, idata, odata);
if (hipfft_rt != HIPFFT_SUCCESS)
throw std::runtime_error("hipfftExecC2R failed");
}

void _exec(hipfftHandle plan, hipfftDoubleComplex* idata,
void _exec(hipfftHandle& plan, hipfftDoubleComplex* idata,
hipfftDoubleReal* odata, [[maybe_unused]] int direction) {
hipfftResult hipfft_rt = hipfftExecZ2D(plan, idata, odata);
if (hipfft_rt != HIPFFT_SUCCESS)
throw std::runtime_error("hipfftExecZ2D failed");
}

void _exec(hipfftHandle plan, hipfftComplex* idata, hipfftComplex* odata,
void _exec(hipfftHandle& plan, hipfftComplex* idata, hipfftComplex* odata,
int direction) {
hipfftResult hipfft_rt = hipfftExecC2C(plan, idata, odata, direction);
if (hipfft_rt != HIPFFT_SUCCESS)
throw std::runtime_error("hipfftExecC2C failed");
}

void _exec(hipfftHandle plan, hipfftDoubleComplex* idata,
void _exec(hipfftHandle& plan, hipfftDoubleComplex* idata,
hipfftDoubleComplex* odata, int direction) {
hipfftResult hipfft_rt = hipfftExecZ2Z(plan, idata, odata, direction);
if (hipfft_rt != HIPFFT_SUCCESS)
Expand Down

0 comments on commit 92c0ad3

Please sign in to comment.