Skip to content

Commit

Permalink
Refactor serial tbsv implementation details and tests (#2478)
Browse files Browse the repository at this point in the history
* refactor serial tbsv implementation details and tests

Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp>

* fix: test names in SerialTbsv complex

Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp>

* use EXPECT_NEAR_KK_REL for comparison

Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp>

* use EXPECT_NEAR_KK_REL for general tests

Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp>

* Add docstring and assertion for Arg parameters

Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp>

---------

Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp>
Co-authored-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp>
  • Loading branch information
yasahi-hpc and Yuuichi Asahi authored Feb 19, 2025
1 parent b5ec4ab commit 9b4703d
Show file tree
Hide file tree
Showing 7 changed files with 284 additions and 180 deletions.
11 changes: 7 additions & 4 deletions batched/dense/impl/KokkosBatched_Pbtrs_Serial_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#ifndef KOKKOSBATCHED_PBTRS_SERIAL_INTERNAL_HPP_
#define KOKKOSBATCHED_PBTRS_SERIAL_INTERNAL_HPP_

#include "KokkosBlas_util.hpp"
#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Tbsv_Serial_Internal.hpp"

Expand Down Expand Up @@ -50,8 +51,9 @@ KOKKOS_INLINE_FUNCTION int SerialPbtrsInternalLower<Algo::Pbtrs::Unblocked>::inv
SerialTbsvInternalLower<Algo::Tbsv::Unblocked>::invoke(false, an, A, as0, as1, x, xs0, kd);

// Solve L**T *X = B, overwriting B with X.
constexpr bool do_conj = Kokkos::ArithTraits<ValueType>::is_complex;
SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(false, do_conj, an, A, as0, as1, x, xs0, kd);
using op =
std::conditional_t<Kokkos::ArithTraits<ValueType>::is_complex, KokkosBlas::Impl::OpConj, KokkosBlas::Impl::OpID>;
SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(op(), false, an, A, as0, as1, x, xs0, kd);

return 0;
}
Expand All @@ -76,8 +78,9 @@ KOKKOS_INLINE_FUNCTION int SerialPbtrsInternalUpper<Algo::Pbtrs::Unblocked>::inv
/**/ ValueType *KOKKOS_RESTRICT x,
const int xs0, const int kd) {
// Solve U**T *X = B, overwriting B with X.
constexpr bool do_conj = Kokkos::ArithTraits<ValueType>::is_complex;
SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(false, do_conj, an, A, as0, as1, x, xs0, kd);
using op =
std::conditional_t<Kokkos::ArithTraits<ValueType>::is_complex, KokkosBlas::Impl::OpConj, KokkosBlas::Impl::OpID>;
SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(op(), false, an, A, as0, as1, x, xs0, kd);

// Solve U*X = B, overwriting B with X.
SerialTbsvInternalUpper<Algo::Tbsv::Unblocked>::invoke(false, an, A, as0, as1, x, xs0, kd);
Expand Down
45 changes: 26 additions & 19 deletions batched/dense/impl/KokkosBatched_Tbsv_Serial_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,17 @@

/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr)

#include "KokkosBlas_util.hpp"
#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Tbsv_Serial_Internal.hpp"

namespace KokkosBatched {

namespace Impl {
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int checkTbsvInput([[maybe_unused]] const AViewType &A,
[[maybe_unused]] const XViewType &x, [[maybe_unused]] const int k) {
static_assert(Kokkos::is_view<AViewType>::value, "KokkosBatched::tbsv: AViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<XViewType>::value, "KokkosBatched::tbsv: XViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view_v<AViewType>, "KokkosBatched::tbsv: AViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view_v<XViewType>, "KokkosBatched::tbsv: XViewType is not a Kokkos::View.");
static_assert(AViewType::rank == 2, "KokkosBatched::tbsv: AViewType must have rank 2.");
static_assert(XViewType::rank == 1, "KokkosBatched::tbsv: XViewType must have rank 1.");

Expand Down Expand Up @@ -63,15 +64,17 @@ KOKKOS_INLINE_FUNCTION static int checkTbsvInput([[maybe_unused]] const AViewTyp
return 0;
}

} // namespace Impl

//// Lower non-transpose ////
/// SerialTbsv specialization selected for Uplo::Lower, Trans::NoTranspose and
/// Algo::Tbsv::Unblocked; ArgDiag supplies the use_unit_diag flag.
///
/// invoke() first runs checkTbsvInput(A, x, k) and returns its result when it
/// is non-zero. Otherwise it forwards ArgDiag::use_unit_diag, A.extent(1), the
/// data pointers and strides of A and x, and k to SerialTbsvInternalLower and
/// returns that call's result.
///
/// NOTE(review): this listing looks like a rendered diff without +/- markers.
/// The unqualified `checkTbsvInput` / `SerialTbsvInternalLower` lines appear to
/// be the removed versions and the `Impl::`-qualified lines their replacements;
/// confirm against the commit before treating both as live code.
template <typename ArgDiag>
struct SerialTbsv<Uplo::Lower, Trans::NoTranspose, ArgDiag, Algo::Tbsv::Unblocked> {
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
auto info = checkTbsvInput(A, x, k);
auto info = Impl::checkTbsvInput(A, x, k);
// A non-zero code from the input check is handed back to the caller unchanged.
if (info) return info;

return SerialTbsvInternalLower<Algo::Tbsv::Unblocked>::invoke(
return Impl::SerialTbsvInternalLower<Algo::Tbsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
}
};
Expand All @@ -81,11 +84,12 @@ template <typename ArgDiag>
struct SerialTbsv<Uplo::Lower, Trans::Transpose, ArgDiag, Algo::Tbsv::Unblocked> {
// Specialization selected for Uplo::Lower, Trans::Transpose and Algo::Tbsv::Unblocked.
//
// invoke() returns the result of checkTbsvInput(A, x, k) when it is non-zero;
// otherwise it forwards the extent, data pointers and strides of A and x, plus k,
// to SerialTbsvInternalLowerTranspose and returns that call's result.
//
// NOTE(review): this listing looks like a rendered diff without +/- markers. The
// call that passes `ArgDiag::use_unit_diag, false, ...` appears to be the removed
// form, and the `Impl::`-qualified call that passes `KokkosBlas::Impl::OpID()` as
// its first argument appears to be the replacement; confirm against the commit.
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
auto info = checkTbsvInput(A, x, k);
auto info = Impl::checkTbsvInput(A, x, k);
// A non-zero code from the input check is handed back to the caller unchanged.
if (info) return info;

return SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
return Impl::SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(
KokkosBlas::Impl::OpID(), ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(),
x.stride_0(), k);
}
};

Expand All @@ -94,11 +98,12 @@ template <typename ArgDiag>
struct SerialTbsv<Uplo::Lower, Trans::ConjTranspose, ArgDiag, Algo::Tbsv::Unblocked> {
// Specialization selected for Uplo::Lower, Trans::ConjTranspose and Algo::Tbsv::Unblocked.
//
// invoke() returns the result of checkTbsvInput(A, x, k) when it is non-zero;
// otherwise it forwards the extent, data pointers and strides of A and x, plus k,
// to SerialTbsvInternalLowerTranspose and returns that call's result.
//
// NOTE(review): this listing looks like a rendered diff without +/- markers. The
// call that passes `ArgDiag::use_unit_diag, true, ...` appears to be the removed
// form, and the `Impl::`-qualified call that passes `KokkosBlas::Impl::OpConj()` as
// its first argument appears to be the replacement; confirm against the commit.
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
auto info = checkTbsvInput(A, x, k);
auto info = Impl::checkTbsvInput(A, x, k);
// A non-zero code from the input check is handed back to the caller unchanged.
if (info) return info;

return SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, true, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
return Impl::SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(
KokkosBlas::Impl::OpConj(), ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(),
x.stride_0(), k);
}
};

Expand All @@ -107,10 +112,10 @@ template <typename ArgDiag>
struct SerialTbsv<Uplo::Upper, Trans::NoTranspose, ArgDiag, Algo::Tbsv::Unblocked> {
// Specialization selected for Uplo::Upper, Trans::NoTranspose and Algo::Tbsv::Unblocked.
//
// invoke() returns the result of checkTbsvInput(A, x, k) when it is non-zero;
// otherwise it forwards ArgDiag::use_unit_diag, A.extent(1), the data pointers and
// strides of A and x, and k to SerialTbsvInternalUpper and returns that call's result.
//
// NOTE(review): this listing looks like a rendered diff without +/- markers. The
// unqualified `checkTbsvInput` / `SerialTbsvInternalUpper` lines appear to be the
// removed versions and the `Impl::`-qualified lines their replacements; confirm
// against the commit.
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
auto info = checkTbsvInput(A, x, k);
auto info = Impl::checkTbsvInput(A, x, k);
// A non-zero code from the input check is handed back to the caller unchanged.
if (info) return info;

return SerialTbsvInternalUpper<Algo::Tbsv::Unblocked>::invoke(
return Impl::SerialTbsvInternalUpper<Algo::Tbsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
}
};
Expand All @@ -120,11 +125,12 @@ template <typename ArgDiag>
struct SerialTbsv<Uplo::Upper, Trans::Transpose, ArgDiag, Algo::Tbsv::Unblocked> {
// Specialization selected for Uplo::Upper, Trans::Transpose and Algo::Tbsv::Unblocked.
//
// invoke() returns the result of checkTbsvInput(A, x, k) when it is non-zero;
// otherwise it forwards the extent, data pointers and strides of A and x, plus k,
// to SerialTbsvInternalUpperTranspose and returns that call's result.
//
// NOTE(review): this listing looks like a rendered diff without +/- markers. The
// call that passes `ArgDiag::use_unit_diag, false, ...` appears to be the removed
// form, and the `Impl::`-qualified call that passes `KokkosBlas::Impl::OpID()` as
// its first argument appears to be the replacement; confirm against the commit.
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
auto info = checkTbsvInput(A, x, k);
auto info = Impl::checkTbsvInput(A, x, k);
// A non-zero code from the input check is handed back to the caller unchanged.
if (info) return info;

return SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
return Impl::SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(
KokkosBlas::Impl::OpID(), ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(),
x.stride_0(), k);
}
};

Expand All @@ -133,11 +139,12 @@ template <typename ArgDiag>
struct SerialTbsv<Uplo::Upper, Trans::ConjTranspose, ArgDiag, Algo::Tbsv::Unblocked> {
// Specialization selected for Uplo::Upper, Trans::ConjTranspose and Algo::Tbsv::Unblocked.
//
// invoke() returns the result of checkTbsvInput(A, x, k) when it is non-zero;
// otherwise it forwards the extent, data pointers and strides of A and x, plus k,
// to SerialTbsvInternalUpperTranspose and returns that call's result.
//
// NOTE(review): this listing looks like a rendered diff without +/- markers. The
// call that passes `ArgDiag::use_unit_diag, true, ...` appears to be the removed
// form, and the `Impl::`-qualified call that passes `KokkosBlas::Impl::OpConj()` as
// its first argument appears to be the replacement; confirm against the commit.
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
auto info = checkTbsvInput(A, x, k);
auto info = Impl::checkTbsvInput(A, x, k);
// A non-zero code from the input check is handed back to the caller unchanged.
if (info) return info;

return SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, true, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
return Impl::SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(
KokkosBlas::Impl::OpConj(), ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(),
x.stride_0(), k);
}
};

Expand Down
58 changes: 18 additions & 40 deletions batched/dense/impl/KokkosBatched_Tbsv_Serial_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
#include "KokkosBatched_Util.hpp"

namespace KokkosBatched {

namespace Impl {
///
/// Serial Internal Impl
/// ====================

///
/// Lower, Non-Transpose
/// Lower
///

template <typename AlgoType>
Expand Down Expand Up @@ -70,49 +70,37 @@ KOKKOS_INLINE_FUNCTION int SerialTbsvInternalLower<Algo::Tbsv::Unblocked>::invok

/// Declaration of the internal lower-triangular transpose solver, templated on the
/// algorithm tag; the Algo::Tbsv::Unblocked definition of invoke() follows below.
///
/// NOTE(review): this listing looks like a rendered diff without +/- markers. The
/// `template <typename ValueType>` header with the `(use_unit_diag, do_conj, an, ...)`
/// parameter list appears to be the removed signature, and the
/// `template <typename Op, typename ValueType>` header with `(Op op, use_unit_diag, an, ...)`
/// its replacement; both share the trailing A/x parameter lines. Confirm against the commit.
template <typename AlgoType>
struct SerialTbsvInternalLowerTranspose {
template <typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const bool do_conj, const int an,
template <typename Op, typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(Op op, const bool use_unit_diag, const int an,
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
/**/ ValueType *KOKKOS_RESTRICT x, const int xs0, const int k);
};

/// Unblocked solve behind the Trans::Transpose and Trans::ConjTranspose paths of the
/// lower-triangular SerialTbsv specializations.
///
/// For j = an - 1 down to 0 the op-based loop starts from x[j * xs0], subtracts
/// op(A[(i - j) * as0 + j * as1]) * x[i * xs0] for i = min(an - 1, j + k) down to j + 1,
/// divides by op(A[0 + j * as1]) unless use_unit_diag is set, and stores the result
/// back into x[j * xs0]. The function always returns 0.
///
/// \param op             callable applied to every value read from A (the wrappers in this
///                       commit pass KokkosBlas::Impl::OpID() or KokkosBlas::Impl::OpConj())
/// \param use_unit_diag  when true, the division by the op(A[0 + j * as1]) entry is skipped
/// \param an             number of entries of x that are solved for (wrappers pass A.extent(1))
/// \param A              read-only matrix data, indexed with multipliers as0 and as1
/// \param x              right-hand side on entry, overwritten entry by entry with the result
/// \param xs0            index multiplier for x
/// \param k              limits the inner loop to at most k terms per j
///
/// NOTE(review): this listing looks like a rendered diff without +/- markers. The
/// `template <typename ValueType>` header, the parameter list containing `do_conj`, and the
/// whole `if (do_conj) { ... } else { ... }` section appear to be the removed
/// implementation; the `Op op` parameter list and the single loop using `op(...)` appear
/// to be the replacement. Confirm against the commit before reading both as live code.
template <>
template <typename ValueType>
template <typename Op, typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(
const bool use_unit_diag, const bool do_conj, const int an, const ValueType *KOKKOS_RESTRICT A, const int as0,
const int as1,
Op op, const bool use_unit_diag, const int an, const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
/**/ ValueType *KOKKOS_RESTRICT x, const int xs0, const int k) {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int j = an - 1; j >= 0; --j) {
auto temp = x[j * xs0];

// Branch on the runtime flag: the first arm applies ArithTraits::conj to every A entry,
// the second arm uses the entries as stored.
if (do_conj) {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int i = Kokkos::min(an - 1, j + k); i > j; --i) {
temp -= Kokkos::ArithTraits<ValueType>::conj(A[(i - j) * as0 + j * as1]) * x[i * xs0];
}
if (!use_unit_diag) temp = temp / Kokkos::ArithTraits<ValueType>::conj(A[0 + j * as1]);
} else {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int i = Kokkos::min(an - 1, j + k); i > j; --i) {
temp -= A[(i - j) * as0 + j * as1] * x[i * xs0];
}
if (!use_unit_diag) temp = temp / A[0 + j * as1];
// Same update with the element transformation delegated to `op`.
for (int i = Kokkos::min(an - 1, j + k); i > j; --i) {
temp -= op(A[(i - j) * as0 + j * as1]) * x[i * xs0];
}
if (!use_unit_diag) temp = temp / op(A[0 + j * as1]);
x[j * xs0] = temp;
}

return 0;
}

///
/// Upper, Non-Transpose
/// Upper
///

template <typename AlgoType>
Expand Down Expand Up @@ -154,46 +142,36 @@ KOKKOS_INLINE_FUNCTION int SerialTbsvInternalUpper<Algo::Tbsv::Unblocked>::invok

/// Declaration of the internal upper-triangular transpose solver, templated on the
/// algorithm tag; the Algo::Tbsv::Unblocked definition of invoke() follows below.
///
/// NOTE(review): this listing looks like a rendered diff without +/- markers. The
/// `template <typename ValueType>` header with the `(use_unit_diag, do_conj, an, ...)`
/// parameter list appears to be the removed signature, and the
/// `template <typename Op, typename ValueType>` header with `(Op op, use_unit_diag, an, ...)`
/// its replacement; both share the trailing A/x parameter lines. Confirm against the commit.
template <typename AlgoType>
struct SerialTbsvInternalUpperTranspose {
template <typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const bool do_conj, const int an,
template <typename Op, typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(Op op, const bool use_unit_diag, const int an,
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
/**/ ValueType *KOKKOS_RESTRICT x, const int xs0, const int k);
};

/// Unblocked solve behind the Trans::Transpose and Trans::ConjTranspose paths of the
/// upper-triangular SerialTbsv specializations.
///
/// For j = 0 up to an - 1 the op-based loop starts from x[j * xs0], subtracts
/// op(A[(i + k - j) * as0 + j * as1]) * x[i * xs0] for i = max(0, j - k) up to j - 1,
/// divides by op(A[k * as0 + j * as1]) unless use_unit_diag is set, and stores the
/// result back into x[j * xs0]. The function always returns 0.
///
/// \param op             callable applied to every value read from A (the wrappers in this
///                       commit pass KokkosBlas::Impl::OpID() or KokkosBlas::Impl::OpConj())
/// \param use_unit_diag  when true, the division by the op(A[k * as0 + j * as1]) entry is skipped
/// \param an             number of entries of x that are solved for (wrappers pass A.extent(1))
/// \param A              read-only matrix data, indexed with multipliers as0 and as1
/// \param x              right-hand side on entry, overwritten entry by entry with the result
/// \param xs0            index multiplier for x
/// \param k              limits the inner loop to at most k terms per j and selects the
///                       row offset (k * as0) of the entry used as divisor
///
/// NOTE(review): this listing looks like a rendered diff without +/- markers. The
/// `template <typename ValueType>` header, the parameter list containing `do_conj`, and the
/// whole `if (do_conj) { ... } else { ... }` section appear to be the removed
/// implementation; the `Op op` parameter list and the single loop using `op(...)` appear
/// to be the replacement. Confirm against the commit before reading both as live code.
template <>
template <typename ValueType>
template <typename Op, typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(
const bool use_unit_diag, const bool do_conj, const int an, const ValueType *KOKKOS_RESTRICT A, const int as0,
const int as1,
Op op, const bool use_unit_diag, const int an, const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
/**/ ValueType *KOKKOS_RESTRICT x, const int xs0, const int k) {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int j = 0; j < an; j++) {
auto temp = x[j * xs0];
// Branch on the runtime flag: the first arm applies ArithTraits::conj to every A entry,
// the second arm uses the entries as stored.
if (do_conj) {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int i = Kokkos::max(0, j - k); i < j; ++i) {
temp -= Kokkos::ArithTraits<ValueType>::conj(A[(i + k - j) * as0 + j * as1]) * x[i * xs0];
}
if (!use_unit_diag) temp = temp / Kokkos::ArithTraits<ValueType>::conj(A[k * as0 + j * as1]);
} else {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int i = Kokkos::max(0, j - k); i < j; ++i) {
temp -= A[(i + k - j) * as0 + j * as1] * x[i * xs0];
}
if (!use_unit_diag) temp = temp / A[k * as0 + j * as1];
// Same update with the element transformation delegated to `op`.
for (int i = Kokkos::max(0, j - k); i < j; ++i) {
temp -= op(A[(i + k - j) * as0 + j * as1]) * x[i * xs0];
}
if (!use_unit_diag) temp = temp / op(A[k * as0 + j * as1]);
x[j * xs0] = temp;
}

return 0;
}

} // namespace Impl
} // namespace KokkosBatched

#endif // KOKKOSBATCHED_TBSV_SERIAL_INTERNAL_HPP_
19 changes: 19 additions & 0 deletions batched/dense/src/KokkosBatched_Tbsv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ namespace KokkosBatched {
/// non-unit, upper or lower triangular band matrix, with ( k + 1 )
/// diagonals.
///
/// \tparam ArgUplo: Type indicating whether A is the upper (Uplo::Upper) or lower (Uplo::Lower) triangular matrix
/// \tparam ArgTrans: Type indicating the equations to be solved as follows
/// - Trans::NoTranspose: A * X = B
/// - Trans::Transpose: A**T * X = B
/// - Trans::ConjTranspose: A**H * X = B
/// \tparam ArgDiag: Type indicating whether A is the unit (Diag::Unit) or non-unit (Diag::NonUnit) triangular matrix
/// \tparam ArgAlgo: Type indicating the blocked (KokkosBatched::Algo::Tbsv::Blocked) or unblocked
/// (KokkosBatched::Algo::Tbsv::Unblocked) algorithm to be used; currently only
/// Algo::Tbsv::Unblocked is accepted (enforced by a static_assert in SerialTbsv)
///
/// \tparam AViewType: Input type for the matrix, needs to be a 2D view
/// \tparam XViewType: Input type for the right-hand side and the solution,
/// needs to be a 1D view
Expand All @@ -43,6 +52,16 @@ namespace KokkosBatched {

template <typename ArgUplo, typename ArgTrans, typename ArgDiag, typename ArgAlgo>
struct SerialTbsv {
static_assert(
std::is_same_v<ArgUplo, Uplo::Upper> || std::is_same_v<ArgUplo, Uplo::Lower>,
"KokkosBatched::tbsv: Use Uplo::Upper for upper triangular matrix or Uplo::Lower for lower triangular matrix");
static_assert(std::is_same_v<ArgTrans, Trans::NoTranspose> || std::is_same_v<ArgTrans, Trans::Transpose> ||
std::is_same_v<ArgTrans, Trans::ConjTranspose>,
"KokkosBatched::tbsv: Use Trans::NoTranspose, Trans::Transpose or Trans::ConjTranspose");
static_assert(
std::is_same_v<ArgDiag, Diag::Unit> || std::is_same_v<ArgDiag, Diag::NonUnit>,
"KokkosBatched::tbsv: Use Diag::Unit for unit triangular matrix or Diag::NonUnit for non-unit triangular matrix");
static_assert(std::is_same_v<ArgAlgo, Algo::Tbsv::Unblocked>, "KokkosBatched::tbsv: Use Algo::Tbsv::Unblocked");
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &X, const int k);
};
Expand Down
Loading

0 comments on commit 9b4703d

Please sign in to comment.