forked from kokkos/kokkos-kernels
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement batched serial syr (kokkos#2497)
* Introduce OpReal functor to provide real operator Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp> * implement batched serial syr Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp> * remove unused variable Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp> * fix view constructor Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp> * Add docstring and assertion for ArgUplo and ArgTrans parameters Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp> --------- Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp> Co-authored-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp>
- Loading branch information
1 parent
609667f
commit fbb9b8d
Showing
6 changed files
with
733 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
//@HEADER | ||
// ************************************************************************ | ||
// | ||
// Kokkos v. 4.0 | ||
// Copyright (2022) National Technology & Engineering | ||
// Solutions of Sandia, LLC (NTESS). | ||
// | ||
// Under the terms of Contract DE-NA0003525 with NTESS, | ||
// the U.S. Government retains certain rights in this software. | ||
// | ||
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://kokkos.org/LICENSE for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//@HEADER | ||
|
||
#ifndef KOKKOSBATCHED_SYR_SERIAL_IMPL_HPP_ | ||
#define KOKKOSBATCHED_SYR_SERIAL_IMPL_HPP_ | ||
|
||
#include <KokkosBlas_util.hpp> | ||
#include <KokkosBatched_Util.hpp> | ||
#include "KokkosBatched_Syr_Serial_Internal.hpp" | ||
|
||
namespace KokkosBatched { | ||
namespace Impl { | ||
template <typename XViewType, typename AViewType> | ||
KOKKOS_INLINE_FUNCTION static int checkSyrInput([[maybe_unused]] const XViewType &x, | ||
[[maybe_unused]] const AViewType &A) { | ||
static_assert(Kokkos::is_view_v<XViewType>, "KokkosBatched::syr: XViewType is not a Kokkos::View."); | ||
static_assert(Kokkos::is_view_v<AViewType>, "KokkosBatched::syr: AViewType is not a Kokkos::View."); | ||
static_assert(XViewType::rank == 1, "KokkosBatched::syr: XViewType must have rank 1."); | ||
static_assert(AViewType::rank == 2, "KokkosBatched::syr: AViewType must have rank 2."); | ||
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0) | ||
const int lda = A.extent_int(0), n = A.extent_int(1); | ||
|
||
if (n < 0) { | ||
Kokkos::printf( | ||
"KokkosBatched::syr: input parameter n must not be less than 0: n " | ||
"= " | ||
"%d\n", | ||
n); | ||
return 1; | ||
} | ||
|
||
if (x.extent_int(0) != n) { | ||
Kokkos::printf( | ||
"KokkosBatched::syr: x must contain n elements: n " | ||
"= " | ||
"%d\n", | ||
n); | ||
return 1; | ||
} | ||
|
||
if (lda < Kokkos::max(1, n)) { | ||
Kokkos::printf( | ||
"KokkosBatched::syr: leading dimension of A must not be smaller than " | ||
"max(1, n): " | ||
"lda = %d, n = %d\n", | ||
lda, n); | ||
return 1; | ||
} | ||
#endif | ||
return 0; | ||
} | ||
} // namespace Impl | ||
|
||
// {s,d,c,z}syr interface | ||
// L T | ||
// A: alpha * x * x**T + A | ||
template <> | ||
struct SerialSyr<Uplo::Lower, Trans::Transpose> { | ||
template <typename ScalarType, typename XViewType, typename AViewType> | ||
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const XViewType &x, const AViewType &A) { | ||
// Quick return if possible | ||
const int n = A.extent_int(1); | ||
if (n == 0 || (alpha == ScalarType(0))) return 0; | ||
|
||
auto info = Impl::checkSyrInput(x, A); | ||
if (info) return info; | ||
|
||
return Impl::SerialSyrInternalLower::invoke(KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), n, alpha, x.data(), | ||
x.stride(0), A.data(), A.stride(0), A.stride(1)); | ||
} | ||
}; | ||
|
||
// {s,d,c,z}syr interface | ||
// U T | ||
// A: alpha * x * x**T + A | ||
template <> | ||
struct SerialSyr<Uplo::Upper, Trans::Transpose> { | ||
template <typename ScalarType, typename XViewType, typename AViewType> | ||
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const XViewType &x, const AViewType &A) { | ||
// Quick return if possible | ||
const int n = A.extent_int(1); | ||
if (n == 0 || (alpha == ScalarType(0))) return 0; | ||
|
||
auto info = Impl::checkSyrInput(x, A); | ||
if (info) return info; | ||
|
||
return Impl::SerialSyrInternalUpper::invoke(KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), n, alpha, x.data(), | ||
x.stride(0), A.data(), A.stride(0), A.stride(1)); | ||
} | ||
}; | ||
|
||
// {c,z}her interface | ||
// L C | ||
// A: alpha * x * x**H + A | ||
template <> | ||
struct SerialSyr<Uplo::Lower, Trans::ConjTranspose> { | ||
template <typename ScalarType, typename XViewType, typename AViewType> | ||
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const XViewType &x, const AViewType &A) { | ||
// Quick return if possible | ||
const int n = A.extent_int(1); | ||
if (n == 0 || (alpha == ScalarType(0))) return 0; | ||
|
||
auto info = Impl::checkSyrInput(x, A); | ||
if (info) return info; | ||
|
||
return Impl::SerialSyrInternalLower::invoke(KokkosBlas::Impl::OpConj(), KokkosBlas::Impl::OpReal(), n, alpha, | ||
x.data(), x.stride(0), A.data(), A.stride(0), A.stride(1)); | ||
} | ||
}; | ||
|
||
// {c,z}her interface | ||
// U C | ||
// A: alpha * x * x**H + A | ||
template <> | ||
struct SerialSyr<Uplo::Upper, Trans::ConjTranspose> { | ||
template <typename ScalarType, typename XViewType, typename AViewType> | ||
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const XViewType &x, const AViewType &A) { | ||
// Quick return if possible | ||
const int n = A.extent_int(1); | ||
if (n == 0 || (alpha == ScalarType(0))) return 0; | ||
|
||
auto info = Impl::checkSyrInput(x, A); | ||
if (info) return info; | ||
|
||
return Impl::SerialSyrInternalUpper::invoke(KokkosBlas::Impl::OpConj(), KokkosBlas::Impl::OpReal(), n, alpha, | ||
x.data(), x.stride(0), A.data(), A.stride(0), A.stride(1)); | ||
} | ||
}; | ||
|
||
} // namespace KokkosBatched | ||
|
||
#endif // KOKKOSBATCHED_SYR_SERIAL_IMPL_HPP_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
//@HEADER | ||
// ************************************************************************ | ||
// | ||
// Kokkos v. 4.0 | ||
// Copyright (2022) National Technology & Engineering | ||
// Solutions of Sandia, LLC (NTESS). | ||
// | ||
// Under the terms of Contract DE-NA0003525 with NTESS, | ||
// the U.S. Government retains certain rights in this software. | ||
// | ||
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://kokkos.org/LICENSE for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//@HEADER | ||
|
||
#ifndef KOKKOSBATCHED_SYR_SERIAL_INTERNAL_HPP_ | ||
#define KOKKOSBATCHED_SYR_SERIAL_INTERNAL_HPP_ | ||
|
||
#include <KokkosBatched_Util.hpp> | ||
|
||
namespace KokkosBatched { | ||
namespace Impl { | ||
|
||
/// | ||
/// Serial Internal Impl | ||
/// ==================== | ||
|
||
/// Lower | ||
|
||
struct SerialSyrInternalLower { | ||
template <typename Op, typename SymOp, typename ScalarType, typename ValueType> | ||
KOKKOS_INLINE_FUNCTION static int invoke(Op op, SymOp sym_op, const int an, const ScalarType alpha, | ||
const ValueType *KOKKOS_RESTRICT x, const int xs0, | ||
ValueType *KOKKOS_RESTRICT A, const int as0, const int as1); | ||
}; | ||
|
||
template <typename Op, typename SymOp, typename ScalarType, typename ValueType> | ||
KOKKOS_INLINE_FUNCTION int SerialSyrInternalLower::invoke(Op op, SymOp sym_op, const int an, const ScalarType alpha, | ||
const ValueType *KOKKOS_RESTRICT x, const int xs0, | ||
ValueType *KOKKOS_RESTRICT A, const int as0, const int as1) { | ||
for (int j = 0; j < an; j++) { | ||
if (x[j * xs0] != ValueType(0)) { | ||
auto temp = alpha * op(x[j * xs0]); | ||
A[j * as0 + j * as1] = sym_op(A[j * as0 + j * as1] + x[j * xs0] * temp); | ||
for (int i = j + 1; i < an; i++) { | ||
A[i * as0 + j * as1] += x[i * xs0] * temp; | ||
} | ||
} else { | ||
A[j * as0 + j * as1] = sym_op(A[j * as0 + j * as1]); | ||
} | ||
} | ||
|
||
return 0; | ||
} | ||
|
||
/// Upper | ||
|
||
struct SerialSyrInternalUpper { | ||
template <typename Op, typename SymOp, typename ScalarType, typename ValueType> | ||
KOKKOS_INLINE_FUNCTION static int invoke(Op op, SymOp sym_op, const int an, const ScalarType alpha, | ||
const ValueType *KOKKOS_RESTRICT x, const int xs0, | ||
ValueType *KOKKOS_RESTRICT A, const int as0, const int as1); | ||
}; | ||
|
||
template <typename Op, typename SymOp, typename ScalarType, typename ValueType> | ||
KOKKOS_INLINE_FUNCTION int SerialSyrInternalUpper::invoke(Op op, SymOp sym_op, const int an, const ScalarType alpha, | ||
const ValueType *KOKKOS_RESTRICT x, const int xs0, | ||
ValueType *KOKKOS_RESTRICT A, const int as0, const int as1) { | ||
for (int j = 0; j < an; j++) { | ||
if (x[j * xs0] != ValueType(0)) { | ||
auto temp = alpha * op(x[j * xs0]); | ||
for (int i = 0; i < j; i++) { | ||
A[i * as0 + j * as1] += x[i * xs0] * temp; | ||
} | ||
A[j * as0 + j * as1] = sym_op(A[j * as0 + j * as1] + x[j * xs0] * temp); | ||
} else { | ||
A[j * as0 + j * as1] = sym_op(A[j * as0 + j * as1]); | ||
} | ||
} | ||
|
||
return 0; | ||
} | ||
|
||
} // namespace Impl | ||
} // namespace KokkosBatched | ||
|
||
#endif // KOKKOSBATCHED_SYR_SERIAL_INTERNAL_HPP_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
//@HEADER | ||
// ************************************************************************ | ||
// | ||
// Kokkos v. 4.0 | ||
// Copyright (2022) National Technology & Engineering | ||
// Solutions of Sandia, LLC (NTESS). | ||
// | ||
// Under the terms of Contract DE-NA0003525 with NTESS, | ||
// the U.S. Government retains certain rights in this software. | ||
// | ||
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://kokkos.org/LICENSE for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//@HEADER | ||
#ifndef KOKKOSBATCHED_SYR_HPP_ | ||
#define KOKKOSBATCHED_SYR_HPP_ | ||
|
||
#include <KokkosBatched_Util.hpp> | ||
|
||
/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr) | ||
|
||
namespace KokkosBatched { | ||
|
||
/// \brief Serial Batched Syr: | ||
/// Performs the symmetric rank 1 operation | ||
/// A := alpha*x*x**T + A or A := alpha*x*x**H + A | ||
/// where alpha is a scalar, x is an n element vector, and A is a n by n symmetric or Hermitian matrix. | ||
/// | ||
/// \tparam ArgUplo: Type indicating whether the upper (Uplo::Upper) or lower (Uplo::Lower) triangular part of A is | ||
/// modified | ||
/// \tparam ArgTrans: Type indicating whether the transpose (Trans::Transpose) or conjugate transpose | ||
/// (Trans::ConjTranspose) of x is used | ||
/// | ||
/// \tparam ScalarType: Input type for the scalar alpha | ||
/// \tparam XViewType: Input type for the vector x, needs to be a 1D view | ||
/// \tparam AViewType: Input/output type for the matrix A, needs to be a 2D view | ||
/// | ||
/// \param alpha [in]: alpha is a scalar | ||
/// \param x [in]: x is a length n vector, a rank 1 view | ||
/// \param A [inout]: A is a n by n matrix, a rank 2 view | ||
/// | ||
/// No nested parallel_for is used inside of the function. | ||
/// | ||
template <typename ArgUplo, typename ArgTrans> | ||
struct SerialSyr { | ||
static_assert( | ||
std::is_same_v<ArgUplo, Uplo::Upper> || std::is_same_v<ArgUplo, Uplo::Lower>, | ||
"KokkosBatched::syr: Use Uplo::Upper for upper triangular matrix or Uplo::Lower for lower triangular matrix"); | ||
static_assert(std::is_same_v<ArgTrans, Trans::Transpose> || std::is_same_v<ArgTrans, Trans::ConjTranspose>, | ||
"KokkosBatched::syr: Use Trans::Transpose for {s,d,c,z}syr or Trans::ConjTranspose for {c,z}her"); | ||
template <typename ScalarType, typename XViewType, typename AViewType> | ||
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const XViewType &x, const AViewType &a); | ||
}; | ||
} // namespace KokkosBatched | ||
|
||
#include "KokkosBatched_Syr_Serial_Impl.hpp" | ||
|
||
#endif // KOKKOSBATCHED_SYR_HPP_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.