Skip to content

Commit

Permalink
Lapack - SVD: fixing more TPL static_assert for LayoutLeft
Browse files Browse the repository at this point in the history
  • Loading branch information
lucbv committed Feb 5, 2024
1 parent a9048fb commit 7fdefc9
Showing 1 changed file with 36 additions and 31 deletions.
67 changes: 36 additions & 31 deletions lapack/tpls/KokkosLapack_svd_tpl_spec_decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,17 +219,18 @@ void mklSvdWrapper(const ExecutionSpace& /* space */, const char jobu[],
using ULayout_t = typename UMatrix::array_layout;
using VLayout_t = typename VMatrix::array_layout;

const lapack_int m = A.extent_int(0);
const lapack_int n = A.extent_int(1);
const lapack_int lda = std::is_same_v<ALayout_t, Kokkos::LayoutRight>
? A.stride(0)
: A.stride(1);
const lapack_int ldu = std::is_same_v<ULayout_t, Kokkos::LayoutRight>
? U.stride(0)
: U.stride(1);
const lapack_int ldvt = std::is_same_v<VLayout_t, Kokkos::LayoutRight>
? Vt.stride(0)
: Vt.stride(1);
static_assert(std::is_same_v<ALayout_t, Kokkos::LayoutLeft>,
"KokkosLapack - svd: A needs to have a Kokkos::LayoutLeft");
static_assert(std::is_same_v<ULayout_t, Kokkos::LayoutLeft>,
"KokkosLapack - svd: U needs to have a Kokkos::LayoutLeft");
static_assert(std::is_same_v<VLayout_t, Kokkos::LayoutLeft>,
"KokkosLapack - svd: Vt needs to have a Kokkos::LayoutLeft");

const lapack_int m = A.extent_int(0);
const lapack_int n = A.extent_int(1);
const lapack_int lda = A.stride(1);
const lapack_int ldu = U.stride(1);
const lapack_int ldvt = Vt.stride(1);

Kokkos::View<Magnitude*, memory_space> rwork("svd rwork buffer",
Kokkos::min(m, n) - 1);
Expand Down Expand Up @@ -376,15 +377,18 @@ void cusolverSvdWrapper(const ExecutionSpace& space, const char jobu[],
using ULayout_t = typename UMatrix::array_layout;
using VLayout_t = typename VMatrix::array_layout;

const int m = A.extent_int(0);
const int n = A.extent_int(1);
const int lda = std::is_same_v<ALayout_t, Kokkos::LayoutRight> ? A.stride(0)
: A.stride(1);
const int ldu = std::is_same_v<ULayout_t, Kokkos::LayoutRight> ? U.stride(0)
: U.stride(1);
const int ldvt = std::is_same_v<VLayout_t, Kokkos::LayoutRight>
? Vt.stride(0)
: Vt.stride(1);
static_assert(std::is_same_v<ALayout_t, Kokkos::LayoutLeft>,
"KokkosLapack - svd: A needs to have a Kokkos::LayoutLeft");
static_assert(std::is_same_v<ULayout_t, Kokkos::LayoutLeft>,
"KokkosLapack - svd: U needs to have a Kokkos::LayoutLeft");
static_assert(std::is_same_v<VLayout_t, Kokkos::LayoutLeft>,
"KokkosLapack - svd: Vt needs to have a Kokkos::LayoutLeft");

const int m = A.extent_int(0);
const int n = A.extent_int(1);
const int lda = A.stride(1);
const int ldu = U.stride(1);
const int ldvt = Vt.stride(1);

int lwork = 0;
Kokkos::View<int, memory_space> info("svd info");
Expand Down Expand Up @@ -536,17 +540,18 @@ void rocsolverSvdWrapper(const ExecutionSpace& space, const char jobu[],
using ULayout_t = typename UMatrix::array_layout;
using VLayout_t = typename VMatrix::array_layout;

const rocblas_int m = A.extent_int(0);
const rocblas_int n = A.extent_int(1);
const rocblas_int lda = std::is_same_v<ALayout_t, Kokkos::LayoutRight>
? A.stride(0)
: A.stride(1);
const rocblas_int ldu = std::is_same_v<ULayout_t, Kokkos::LayoutRight>
? U.stride(0)
: U.stride(1);
const rocblas_int ldvt = std::is_same_v<VLayout_t, Kokkos::LayoutRight>
? Vt.stride(0)
: Vt.stride(1);
static_assert(std::is_same_v<ALayout_t, Kokkos::LayoutLeft>,
"KokkosLapack - svd: A needs to have a Kokkos::LayoutLeft");
static_assert(std::is_same_v<ULayout_t, Kokkos::LayoutLeft>,
"KokkosLapack - svd: U needs to have a Kokkos::LayoutLeft");
static_assert(std::is_same_v<VLayout_t, Kokkos::LayoutLeft>,
"KokkosLapack - svd: Vt needs to have a Kokkos::LayoutLeft");

const rocblas_int m = A.extent_int(0);
const rocblas_int n = A.extent_int(1);
const rocblas_int lda = A.stride(1);
const rocblas_int ldu = U.stride(1);
const rocblas_int ldvt = Vt.stride(1);

rocblas_svect UVecMode = rocblas_svect_all;
if ((jobu[0] == 'S') || (jobu[0] == 's')) {
Expand Down

0 comments on commit 7fdefc9

Please sign in to comment.