Skip to content

Commit 1a0c46b

Browse files
yasahi-hpcYuuichi Asahi
and
Yuuichi Asahi
authored
Deprecated remarks to trsv serial impl (#2461)
* Add deprecated warnings for older interfaces of serial trsv Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp> * use if constexpr for selective interface of trsv Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp> * format Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp> * format Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp> --------- Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp> Co-authored-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp>
1 parent f23260a commit 1a0c46b

File tree

2 files changed

+36
-15
lines changed

2 files changed

+36
-15
lines changed

batched/dense/impl/KokkosBatched_Trsv_Serial_Internal.hpp

+21
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,27 @@ KOKKOS_INLINE_FUNCTION int SerialTrsvInternalUpper<Algo::Trsv::Blocked>::invoke(
214214
}
215215

216216
} // namespace Impl
217+
218+
template <typename AlgoType>
219+
struct [[deprecated("Use KokkosBatched::SerialTrsv instead")]] SerialTrsvInternalLower {
220+
template <typename ScalarType, typename ValueType>
221+
KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const int m, const ScalarType alpha,
222+
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
223+
/**/ ValueType *KOKKOS_RESTRICT b, const int bs0) {
224+
return Impl::SerialTrsvInternalLower<AlgoType>::invoke(use_unit_diag, false, m, alpha, A, as0, as1, b, bs0);
225+
}
226+
};
227+
228+
template <typename AlgoType>
229+
struct [[deprecated("Use KokkosBatched::SerialTrsv instead")]] SerialTrsvInternalUpper {
230+
template <typename ScalarType, typename ValueType>
231+
KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const int m, const ScalarType alpha,
232+
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
233+
/**/ ValueType *KOKKOS_RESTRICT b, const int bs0) {
234+
return Impl::SerialTrsvInternalUpper<AlgoType>::invoke(use_unit_diag, false, m, alpha, A, as0, as1, b, bs0);
235+
}
236+
};
237+
217238
} // namespace KokkosBatched
218239

219240
#endif

batched/dense/src/KokkosBatched_Trsv_Decl.hpp

+15-15
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,11 @@ struct Trsv {
7676
KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member, const ScalarType alpha, const AViewType &A,
7777
const bViewType &b) {
7878
int r_val = 0;
79-
if (std::is_same_v<ArgMode, Mode::Serial>) {
79+
if constexpr (std::is_same_v<ArgMode, Mode::Serial>) {
8080
r_val = SerialTrsv<ArgUplo, ArgTrans, ArgDiag, ArgAlgo>::invoke(alpha, A, b);
81-
} else if (std::is_same_v<ArgMode, Mode::Team>) {
81+
} else if constexpr (std::is_same_v<ArgMode, Mode::Team>) {
8282
r_val = TeamTrsv<MemberType, ArgUplo, ArgTrans, ArgDiag, ArgAlgo>::invoke(member, alpha, A, b);
83-
} else if (std::is_same_v<ArgMode, Mode::TeamVector>) {
83+
} else if constexpr (std::is_same_v<ArgMode, Mode::TeamVector>) {
8484
r_val = TeamVectorTrsv<MemberType, ArgUplo, ArgTrans, ArgDiag, ArgAlgo>::invoke(member, alpha, A, b);
8585
}
8686
return r_val;
@@ -155,46 +155,46 @@ struct Trsv {
155155

156156
#define KOKKOSBATCHED_TRSV_LOWER_NO_TRANSPOSE_INTERNAL_INVOKE(MODETYPE, ALGOTYPE, MEMBER, DIAG, M, N, ALPHA, A, AS0, \
157157
AS1, B, BS) \
158-
if (std::is_same_v<MODETYPE, KokkosBatched::Mode::Serial>) { \
158+
if constexpr (std::is_same_v<MODETYPE, KokkosBatched::Mode::Serial>) { \
159159
KOKKOSBATCHED_SERIAL_TRSV_LOWER_NO_TRANSPOSE_INTERNAL_INVOKE(ALGOTYPE, DIAG, M, N, ALPHA, A, AS0, AS1, B, BS); \
160-
} else if (std::is_same_v<MODETYPE, KokkosBatched::Mode::Team>) { \
160+
} else if constexpr (std::is_same_v<MODETYPE, KokkosBatched::Mode::Team>) { \
161161
KOKKOSBATCHED_TEAM_TRSV_LOWER_NO_TRANSPOSE_INTERNAL_INVOKE(ALGOTYPE, MEMBER, DIAG, M, N, ALPHA, A, AS0, AS1, B, \
162162
BS); \
163-
} else if (std::is_same_v<MODETYPE, KokkosBatched::Mode::TeamVector>) { \
163+
} else if constexpr (std::is_same_v<MODETYPE, KokkosBatched::Mode::TeamVector>) { \
164164
KOKKOSBATCHED_TEAMVECTOR_TRSV_LOWER_NO_TRANSPOSE_INTERNAL_INVOKE(ALGOTYPE, MEMBER, DIAG, M, N, ALPHA, A, AS0, AS1, \
165165
B, BS); \
166166
}
167167

168168
#define KOKKOSBATCHED_TRSV_LOWER_TRANSPOSE_INTERNAL_INVOKE(MODETYPE, ALGOTYPE, MEMBER, DIAG, M, N, ALPHA, A, AS0, AS1, \
169169
B, BS) \
170-
if (std::is_same_v<MODETYPE, KokkosBatched::Mode::Serial>) { \
170+
if constexpr (std::is_same_v<MODETYPE, KokkosBatched::Mode::Serial>) { \
171171
KOKKOSBATCHED_SERIAL_TRSV_LOWER_TRANSPOSE_INTERNAL_INVOKE(ALGOTYPE, DIAG, M, N, ALPHA, A, AS0, AS1, B, BS); \
172-
} else if (std::is_same_v<MODETYPE, KokkosBatched::Mode::Team>) { \
172+
} else if constexpr (std::is_same_v<MODETYPE, KokkosBatched::Mode::Team>) { \
173173
KOKKOSBATCHED_TEAM_TRSV_LOWER_TRANSPOSE_INTERNAL_INVOKE(ALGOTYPE, MEMBER, DIAG, M, N, ALPHA, A, AS0, AS1, B, BS); \
174-
} else if (std::is_same_v<MODETYPE, KokkosBatched::Mode::TeamVector>) { \
174+
} else if constexpr (std::is_same_v<MODETYPE, KokkosBatched::Mode::TeamVector>) { \
175175
KOKKOSBATCHED_TEAMVECTOR_TRSV_LOWER_TRANSPOSE_INTERNAL_INVOKE(ALGOTYPE, MEMBER, DIAG, M, N, ALPHA, A, AS0, AS1, B, \
176176
BS); \
177177
}
178178

179179
#define KOKKOSBATCHED_TRSV_UPPER_NO_TRANSPOSE_INTERNAL_INVOKE(MODETYPE, ALGOTYPE, MEMBER, DIAG, M, N, ALPHA, A, AS0, \
180180
AS1, B, BS) \
181-
if (std::is_same_v<MODETYPE, KokkosBatched::Mode::Serial>) { \
181+
if constexpr (std::is_same_v<MODETYPE, KokkosBatched::Mode::Serial>) { \
182182
KOKKOSBATCHED_SERIAL_TRSV_UPPER_NO_TRANSPOSE_INTERNAL_INVOKE(ALGOTYPE, DIAG, M, N, ALPHA, A, AS0, AS1, B, BS); \
183-
} else if (std::is_same_v<MODETYPE, KokkosBatched::Mode::Team>) { \
183+
} else if constexpr (std::is_same_v<MODETYPE, KokkosBatched::Mode::Team>) { \
184184
KOKKOSBATCHED_TEAM_TRSV_UPPER_NO_TRANSPOSE_INTERNAL_INVOKE(ALGOTYPE, MEMBER, DIAG, M, N, ALPHA, A, AS0, AS1, B, \
185185
BS); \
186-
} else if (std::is_same_v<MODETYPE, KokkosBatched::Mode::TeamVector>) { \
186+
} else if constexpr (std::is_same_v<MODETYPE, KokkosBatched::Mode::TeamVector>) { \
187187
KOKKOSBATCHED_TEAMVECTOR_TRSV_UPPER_NO_TRANSPOSE_INTERNAL_INVOKE(ALGOTYPE, MEMBER, DIAG, M, N, ALPHA, A, AS0, AS1, \
188188
B, BS); \
189189
}
190190

191191
#define KOKKOSBATCHED_TRSV_UPPER_TRANSPOSE_INTERNAL_INVOKE(MODETYPE, ALGOTYPE, MEMBER, DIAG, M, N, ALPHA, A, AS0, AS1, \
192192
B, BS) \
193-
if (std::is_same_v<MODETYPE, KokkosBatched::Mode::Serial>) { \
193+
if constexpr (std::is_same_v<MODETYPE, KokkosBatched::Mode::Serial>) { \
194194
KOKKOSBATCHED_SERIAL_TRSV_UPPER_TRANSPOSE_INTERNAL_INVOKE(ALGOTYPE, DIAG, M, N, ALPHA, A, AS0, AS1, B, BS); \
195-
} else if (std::is_same_v<MODETYPE, KokkosBatched::Mode::Team>) { \
195+
} else if constexpr (std::is_same_v<MODETYPE, KokkosBatched::Mode::Team>) { \
196196
KOKKOSBATCHED_TEAM_TRSV_UPPER_TRANSPOSE_INTERNAL_INVOKE(ALGOTYPE, MEMBER, DIAG, M, N, ALPHA, A, AS0, AS1, B, BS); \
197-
} else if (std::is_same_v<MODETYPE, KokkosBatched::Mode::TeamVector>) { \
197+
} else if constexpr (std::is_same_v<MODETYPE, KokkosBatched::Mode::TeamVector>) { \
198198
KOKKOSBATCHED_TEAMVECTOR_TRSV_UPPER_TRANSPOSE_INTERNAL_INVOKE(ALGOTYPE, MEMBER, DIAG, M, N, ALPHA, A, AS0, AS1, B, \
199199
BS); \
200200
}

0 commit comments

Comments
 (0)