diff --git a/lapack/tpls/KokkosLapack_svd_tpl_spec_decl.hpp b/lapack/tpls/KokkosLapack_svd_tpl_spec_decl.hpp index 8bcfc1788a..d3d32a35b2 100644 --- a/lapack/tpls/KokkosLapack_svd_tpl_spec_decl.hpp +++ b/lapack/tpls/KokkosLapack_svd_tpl_spec_decl.hpp @@ -219,17 +219,18 @@ void mklSvdWrapper(const ExecutionSpace& /* space */, const char jobu[], using ULayout_t = typename UMatrix::array_layout; using VLayout_t = typename VMatrix::array_layout; - const lapack_int m = A.extent_int(0); - const lapack_int n = A.extent_int(1); - const lapack_int lda = std::is_same_v - ? A.stride(0) - : A.stride(1); - const lapack_int ldu = std::is_same_v - ? U.stride(0) - : U.stride(1); - const lapack_int ldvt = std::is_same_v - ? Vt.stride(0) - : Vt.stride(1); + static_assert(std::is_same_v, + "KokkosLapack - svd: A needs to have a Kokkos::LayoutLeft"); + static_assert(std::is_same_v, + "KokkosLapack - svd: U needs to have a Kokkos::LayoutLeft"); + static_assert(std::is_same_v, + "KokkosLapack - svd: Vt needs to have a Kokkos::LayoutLeft"); + + const lapack_int m = A.extent_int(0); + const lapack_int n = A.extent_int(1); + const lapack_int lda = A.stride(1); + const lapack_int ldu = U.stride(1); + const lapack_int ldvt = Vt.stride(1); Kokkos::View rwork("svd rwork buffer", Kokkos::min(m, n) - 1); @@ -376,15 +377,18 @@ void cusolverSvdWrapper(const ExecutionSpace& space, const char jobu[], using ULayout_t = typename UMatrix::array_layout; using VLayout_t = typename VMatrix::array_layout; - const int m = A.extent_int(0); - const int n = A.extent_int(1); - const int lda = std::is_same_v ? A.stride(0) - : A.stride(1); - const int ldu = std::is_same_v ? U.stride(0) - : U.stride(1); - const int ldvt = std::is_same_v - ? Vt.stride(0) - : Vt.stride(1); + static_assert(std::is_same_v, + "KokkosLapack - svd: A needs to have a Kokkos::LayoutLeft"); + static_assert(std::is_same_v, + "KokkosLapack - svd: U needs to have a Kokkos::LayoutLeft"); + static_assert(std::is_same_v, + "KokkosLapack - svd: Vt needs to have a Kokkos::LayoutLeft"); + + const int m = A.extent_int(0); + const int n = A.extent_int(1); + const int lda = A.stride(1); + const int ldu = U.stride(1); + const int ldvt = Vt.stride(1); int lwork = 0; Kokkos::View info("svd info"); @@ -536,17 +540,18 @@ void rocsolverSvdWrapper(const ExecutionSpace& space, const char jobu[], using ULayout_t = typename UMatrix::array_layout; using VLayout_t = typename VMatrix::array_layout; - const rocblas_int m = A.extent_int(0); - const rocblas_int n = A.extent_int(1); - const rocblas_int lda = std::is_same_v - ? A.stride(0) - : A.stride(1); - const rocblas_int ldu = std::is_same_v - ? U.stride(0) - : U.stride(1); - const rocblas_int ldvt = std::is_same_v - ? Vt.stride(0) - : Vt.stride(1); + static_assert(std::is_same_v, + "KokkosLapack - svd: A needs to have a Kokkos::LayoutLeft"); + static_assert(std::is_same_v, + "KokkosLapack - svd: U needs to have a Kokkos::LayoutLeft"); + static_assert(std::is_same_v, + "KokkosLapack - svd: Vt needs to have a Kokkos::LayoutLeft"); + + const rocblas_int m = A.extent_int(0); + const rocblas_int n = A.extent_int(1); + const rocblas_int lda = A.stride(1); + const rocblas_int ldu = U.stride(1); + const rocblas_int ldvt = Vt.stride(1); rocblas_svect UVecMode = rocblas_svect_all; if ((jobu[0] == 'S') || (jobu[0] == 's')) {