Skip to content

Commit

Permalink
Merge pull request #2081 from eeprude/axpby_less_deep_copy
Browse files Browse the repository at this point in the history
Axpby using less deep copy (solves issue #2080)
  • Loading branch information
lucbv authored Jan 8, 2024
2 parents 3dafbed + c573d6e commit 93d4cda
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 28 deletions.
65 changes: 44 additions & 21 deletions blas/impl/KokkosBlas1_axpby_unification_attempt_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,23 +104,31 @@ struct AxpbyUnificationAttemptTraits {
// - variable names begin with lower case letters
// - type names begin with upper case letters
// ********************************************************************
private:
public:
static constexpr bool onDevice =
KokkosKernels::Impl::kk_is_gpu_exec_space<tExecSpace>();

private:
static constexpr bool onHost = !onDevice;

public:
static constexpr bool a_is_scalar = !Kokkos::is_view_v<AV>;
static constexpr bool a_is_r0 = Tr0_val<AV>();
static constexpr bool a_is_r1s = Tr1s_val<AV>();
static constexpr bool a_is_r1d = Tr1d_val<AV>();

private:
static constexpr bool a_is_r0 = Tr0_val<AV>();
static constexpr bool a_is_r1s = Tr1s_val<AV>();
static constexpr bool a_is_r1d = Tr1d_val<AV>();

static constexpr bool x_is_r1 = Kokkos::is_view_v<XMV> && (XMV::rank == 1);
static constexpr bool x_is_r2 = Kokkos::is_view_v<XMV> && (XMV::rank == 2);

public:
static constexpr bool b_is_scalar = !Kokkos::is_view_v<BV>;
static constexpr bool b_is_r0 = Tr0_val<BV>();
static constexpr bool b_is_r1s = Tr1s_val<BV>();
static constexpr bool b_is_r1d = Tr1d_val<BV>();

private:
static constexpr bool b_is_r0 = Tr0_val<BV>();
static constexpr bool b_is_r1s = Tr1s_val<BV>();
static constexpr bool b_is_r1d = Tr1d_val<BV>();

static constexpr bool y_is_r1 = Kokkos::is_view_v<YMV> && (YMV::rank == 1);
static constexpr bool y_is_r2 = Kokkos::is_view_v<YMV> && (YMV::rank == 2);
Expand Down Expand Up @@ -220,10 +228,12 @@ struct AxpbyUnificationAttemptTraits {
// 'AtInputScalarTypeA_nonConst'
>;

using InternalTypeA_onDevice =
using InternalTypeA_onDevice = std::conditional_t<
a_is_scalar && b_is_scalar && onDevice, // Keep 'a' as scalar
InternalScalarTypeA,
Kokkos::View<const InternalScalarTypeA*, InternalLayoutA,
typename XMV::device_type,
Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
Kokkos::MemoryTraits<Kokkos::Unmanaged>>>;

using InternalTypeA_onHost = std::conditional_t<
(a_is_r1d || a_is_r1s) && xyRank2Case && onHost,
Expand Down Expand Up @@ -276,13 +286,15 @@ struct AxpbyUnificationAttemptTraits {
// 'AtInputScalarTypeB_nonConst'
>;

using InternalTypeB_onDevice =
using InternalTypeB_onDevice = std::conditional_t<
a_is_scalar && b_is_scalar && onDevice, // Keep 'b' as scalar
InternalScalarTypeB,
Kokkos::View<const InternalScalarTypeB*, InternalLayoutB,
typename YMV::device_type,
Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
Kokkos::MemoryTraits<Kokkos::Unmanaged>>>;

using InternalTypeB_onHost = std::conditional_t<
((b_is_r1d || b_is_r1s) && xyRank2Case && onHost),
(b_is_r1d || b_is_r1s) && xyRank2Case && onHost,
Kokkos::View<const InternalScalarTypeB*, InternalLayoutB,
typename YMV::device_type,
Kokkos::MemoryTraits<Kokkos::Unmanaged>>,
Expand Down Expand Up @@ -614,7 +626,9 @@ struct AxpbyUnificationAttemptTraits {
}
} else {
if constexpr (xyRank1Case) {
constexpr bool internalTypeA_isOk = internalTypeA_is_r1d;
constexpr bool internalTypeA_isOk =
internalTypeA_is_r1d ||
(a_is_scalar && b_is_scalar && internalTypeA_is_scalar);
static_assert(
internalTypeA_isOk,
"KokkosBlas::Impl::AxpbyUnificationAttemptTraits::performChecks()"
Expand All @@ -630,7 +644,9 @@ struct AxpbyUnificationAttemptTraits {
"KokkosBlas::Impl::AxpbyUnificationAttemptTraits::performChecks()"
", onDevice, xyRank1Case: InternalTypeX is wrong");

constexpr bool internalTypeB_isOk = internalTypeB_is_r1d;
constexpr bool internalTypeB_isOk =
internalTypeB_is_r1d ||
(a_is_scalar && b_is_scalar && internalTypeA_is_scalar);
static_assert(
internalTypeB_isOk,
"KokkosBlas::Impl::AxpbyUnificationAttemptTraits::performChecks()"
Expand All @@ -646,7 +662,9 @@ struct AxpbyUnificationAttemptTraits {
"KokkosBlas::Impl::AxpbyUnificationAttemptTraits::performChecks()"
", onDevice, xyRank1Case: InternalTypeY is wrong");
} else {
constexpr bool internalTypeA_isOk = internalTypeA_is_r1d;
constexpr bool internalTypeA_isOk =
internalTypeA_is_r1d ||
(a_is_scalar && b_is_scalar && internalTypeA_is_scalar);
static_assert(
internalTypeA_isOk,
"KokkosBlas::Impl::AxpbyUnificationAttemptTraits::performChecks()"
Expand All @@ -662,7 +680,9 @@ struct AxpbyUnificationAttemptTraits {
"KokkosBlas::Impl::AxpbyUnificationAttemptTraits::performChecks()"
", onDevice, xyRank2Case: InternalTypeX is wrong");

constexpr bool internalTypeB_isOk = internalTypeB_is_r1d;
constexpr bool internalTypeB_isOk =
internalTypeB_is_r1d ||
(a_is_scalar && b_is_scalar && internalTypeB_is_scalar);
static_assert(
internalTypeB_isOk,
"KokkosBlas::Impl::AxpbyUnificationAttemptTraits::performChecks()"
Expand Down Expand Up @@ -703,16 +723,19 @@ struct AxpbyUnificationAttemptTraits {
// ****************************************************************
// We are in the 'onDevice' case, with 2 possible subcases:
//
// 1) xyRank1Case, with only one possible situation:
// - [InternalTypeA / B] = [view<S_a*,1>, view<S_b*,1>]
// 1) xyRank1Case, with the following possible situations:
// - [InternalTypeA, B] = [S_a, S_b], or
// - [InternalTypeA, B] = [view<S_a*,1>, view<S_b*,1>]
//
// or
//
// 2) xyRank2Case, with only one possible situation:
// - [InternalTypeA / B] = [view<S_a*,1 / m>, view<S_b*,1 / m>]
// 2) xyRank2Case, with the following possible situations:
// - [InternalTypeA, B] = [S_a, S_b], or
// - [InternalTypeA, B] = [view<S_a*,1 / m>, view<S_b*,1 / m>]
// ****************************************************************
static_assert(
internalTypesAB_bothViews,
internalTypesAB_bothViews ||
(a_is_scalar && b_is_scalar && internalTypesAB_bothScalars),
"KokkosBlas::Impl::AxpbyUnificationAttemptTraits::performChecks()"
", onDevice, invalid combination of types");
}
Expand Down
41 changes: 34 additions & 7 deletions blas/src/KokkosBlas1_axpby.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,41 @@ void axpby(const execution_space& exec_space, const AV& a, const XMV& X,
InternalTypeY internal_Y = Y;

if constexpr (AxpbyTraits::internalTypesAB_bothScalars) {
InternalTypeA internal_a(Impl::getScalarValueFromVariableAtHost<
AV, Impl::typeRank<AV>()>::getValue(a));
InternalTypeB internal_b(Impl::getScalarValueFromVariableAtHost<
BV, Impl::typeRank<BV>()>::getValue(b));
// ********************************************************************
// The unification logic applies the following general rules:
// 1) In an 'onHost' case, it makes the internal types for 'a' and 'b'
// to be both scalars (hence the name 'internalTypesAB_bothScalars')
// 2) In an 'onDevice' case, it makes the internal types for 'a' and 'b'
// to be Kokkos views. For performance reasons in Trilinos, the only
// exception to this rule is when the input types for both 'a' and
// 'b' are already scalars, in which case the internal types for 'a'
// and 'b' become scalars as well, possibly changing precision in
// order to match the precisions of 'X' and 'Y'.
// ********************************************************************
if constexpr (AxpbyTraits::a_is_scalar && AxpbyTraits::b_is_scalar &&
AxpbyTraits::onDevice) {
// ******************************************************************
// We are in the exception situation for rule 2
// ******************************************************************
InternalTypeA internal_a(a);
InternalTypeA internal_b(b);

Impl::Axpby<execution_space, InternalTypeA, InternalTypeX, InternalTypeB,
InternalTypeY>::axpby(exec_space, internal_a, internal_X,
internal_b, internal_Y);
Impl::Axpby<execution_space, InternalTypeA, InternalTypeX, InternalTypeB,
InternalTypeY>::axpby(exec_space, internal_a, internal_X,
internal_b, internal_Y);
} else {
// ******************************************************************
// We are in rule 1, that is, we are in an 'onHost' case now
// ******************************************************************
InternalTypeA internal_a(Impl::getScalarValueFromVariableAtHost<
AV, Impl::typeRank<AV>()>::getValue(a));
InternalTypeB internal_b(Impl::getScalarValueFromVariableAtHost<
BV, Impl::typeRank<BV>()>::getValue(b));

Impl::Axpby<execution_space, InternalTypeA, InternalTypeX, InternalTypeB,
InternalTypeY>::axpby(exec_space, internal_a, internal_X,
internal_b, internal_Y);
}
} else if constexpr (AxpbyTraits::internalTypesAB_bothViews) {
constexpr bool internalLayoutA_isStride(
std::is_same_v<typename InternalTypeA::array_layout,
Expand Down

0 comments on commit 93d4cda

Please sign in to comment.