Skip to content

Commit eafa364

Browse files
committed
#2240: Add helpers and use Kokkos::View for internals of Rabenseifner when user's payload is View
1 parent 0f23075 commit eafa364

File tree

6 files changed

+417
-165
lines changed

6 files changed

+417
-165
lines changed

src/vt/collective/reduce/allreduce/data_handler.h

+2-28
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,11 @@ class DataHandler<ScalarType, typename std::enable_if<std::is_arithmetic<ScalarT
6666

6767
static std::vector<ScalarType> toVec(const ScalarType& data) { return std::vector<ScalarType>{data}; }
6868
static ScalarType fromVec(const std::vector<ScalarType>& data) { return data[0]; }
69-
static ScalarType fromMemory(ScalarType* data, size_t) {
69+
static ScalarType fromMemory(const ScalarType* data, size_t) {
7070
return *data;
7171
}
7272

73-
// static const ScalarType* data(const ScalarType& data) { return &data; }
7473
static size_t size(const ScalarType&) { return 1; }
75-
// static ScalarType& at(ScalarType& data, size_t) { return data; }
76-
// static void set(ScalarType& data, size_t, const ScalarType& value) { data = value; }
77-
// static ScalarType split(ScalarType&, size_t, size_t) { return ScalarType{}; }
7874
};
7975

8076
template <typename T>
@@ -84,20 +80,11 @@ class DataHandler<std::vector<T>> {
8480

8581
static const std::vector<T>& toVec(const std::vector<T>& data) { return data; }
8682
static std::vector<T> fromVec(const std::vector<T>& data) { return data; }
87-
static std::vector<T> fromMemory(T* data, size_t count) {
83+
static std::vector<T> fromMemory(const T* data, size_t count) {
8884
return std::vector<T>(data, data + count);
8985
}
9086

91-
// static const T* data(const std::vector<T>& data) {return data.data(); }
9287
static size_t size(const std::vector<T>& data) { return data.size(); }
93-
// static T at(const std::vector<T>& data, size_t idx) { return data[idx]; }
94-
// static T& at(std::vector<T>& data, size_t idx) { return data[idx]; }
95-
// static void set(std::vector<T>& data, size_t idx, const T& value) {
96-
// data[idx] = value;
97-
// }
98-
// static std::vector<T> split(std::vector<T>& data, size_t start, size_t end) {
99-
// return std::vector<T>{data.begin() + start, data.begin() + end};
100-
// }
10188
};
10289

10390
#if MAGISTRATE_KOKKOS_ENABLED
@@ -129,20 +116,7 @@ class DataHandler<Kokkos::View<T*, Kokkos::HostSpace, Props...>> {
129116
return view;
130117
}
131118

132-
// static const T* data(const ViewType& data) {return data.data(); }
133119
static size_t size(const ViewType& data) { return data.extent(0); }
134-
135-
// static T at(const ViewType& data, size_t idx) { return data(idx); }
136-
137-
// static T& at(ViewType& data, size_t idx) { return data(idx); }
138-
139-
// static void set(ViewType& data, size_t idx, const T& value) {
140-
// data(idx) = value;
141-
// }
142-
143-
// static ViewType split(ViewType& data, size_t start, size_t end) {
144-
// return Kokkos::subview(data, std::make_pair(start, end));
145-
// }
146120
};
147121

148122
#endif // MAGISTRATE_KOKKOS_ENABLED
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
/*
2+
//@HEADER
3+
// *****************************************************************************
4+
//
5+
// helpers.h
6+
// DARMA/vt => Virtual Transport
7+
//
8+
// Copyright 2019-2021 National Technology & Engineering Solutions of Sandia, LLC
9+
// (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the U.S.
10+
// Government retains certain rights in this software.
11+
//
12+
// Redistribution and use in source and binary forms, with or without
13+
// modification, are permitted provided that the following conditions are met:
14+
//
15+
// * Redistributions of source code must retain the above copyright notice,
16+
// this list of conditions and the following disclaimer.
17+
//
18+
// * Redistributions in binary form must reproduce the above copyright notice,
19+
// this list of conditions and the following disclaimer in the documentation
20+
// and/or other materials provided with the distribution.
21+
//
22+
// * Neither the name of the copyright holder nor the names of its
23+
// contributors may be used to endorse or promote products derived from this
24+
// software without specific prior written permission.
25+
//
26+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
27+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
28+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
29+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
30+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
31+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
32+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
33+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
34+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
35+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
36+
// POSSIBILITY OF SUCH DAMAGE.
37+
//
38+
// Questions? Contact darma@sandia.gov
39+
//
40+
// *****************************************************************************
41+
//@HEADER
42+
*/
43+
44+
#if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_HELPERS_H
45+
#define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_HELPERS_H
46+
#include "data_handler.h"
47+
#include "rabenseifner_msg.h"
48+
#include "vt/messaging/message/shared_message.h"
49+
#include <vector>
50+
51+
namespace vt::collective::reduce::allreduce {
52+
53+
template <typename Scalar, typename DataT>
54+
struct DataHelper {
55+
using DataType = DataHandler<DataT>;
56+
57+
template <typename... Args>
58+
static void assign(std::vector<Scalar>& dest, Args&&... data) {
59+
dest = DataHandler<DataT>::toVec(std::forward<Args>(data)...);
60+
}
61+
62+
static MsgPtr<RabenseifnerMsg<Scalar, DataT>> createMessage(
63+
const std::vector<Scalar>& payload, size_t begin, size_t count, size_t id,
64+
int32_t step = 0) {
65+
return vt::makeMessage<RabenseifnerMsg<Scalar, DataT>>(
66+
payload.data() + begin, count, id, step);
67+
}
68+
69+
static void copy(
70+
std::vector<Scalar>& dest, size_t start_idx, RabenseifnerMsg<Scalar, DataT>* msg) {
71+
for (uint32_t i = 0; i < msg->size_; i++) {
72+
dest[start_idx + i] = msg->val_[i];
73+
}
74+
}
75+
76+
template <template <typename Arg> class Op>
77+
static void reduce(
78+
std::vector<Scalar>& dest, size_t start_idx, RabenseifnerMsg<Scalar, DataT>* msg) {
79+
for (uint32_t i = 0; i < msg->size_; i++) {
80+
Op<Scalar>()(dest[start_idx + i], msg->val_[i]);
81+
}
82+
}
83+
84+
static void invoke() { }
85+
86+
static bool empty(const std::vector<Scalar>& payload) {
87+
return payload.empty();
88+
}
89+
};
90+
91+
#if MAGISTRATE_KOKKOS_ENABLED
92+
93+
template <typename Scalar>
94+
struct DataHelper<Scalar, Kokkos::View<Scalar*, Kokkos::HostSpace>> {
95+
using DataT = Kokkos::View<Scalar*, Kokkos::HostSpace>;
96+
using DataType = DataHandler<DataT>;
97+
98+
template <typename... Args>
99+
static void assign(DataT& dest, Args&&... data) {
100+
dest = {std::forward<Args>(data)...};
101+
}
102+
103+
static MsgPtr<RabenseifnerMsg<Scalar, DataT>> createMessage(
104+
const DataT& payload, size_t begin, size_t count, size_t id,
105+
int32_t step = 0) {
106+
return vt::makeMessage<RabenseifnerMsg<Scalar, DataT>>(
107+
Kokkos::subview(payload, std::make_pair(begin, begin + count)), id, step
108+
);
109+
}
110+
111+
static void
112+
copy(DataT& dest, size_t start_idx, RabenseifnerMsg<Scalar, DataT>* msg) {
113+
Kokkos::parallel_for(
114+
"Rabenseifner::copy", msg->val_.extent(0),
115+
KOKKOS_LAMBDA(const int i) { dest(start_idx + i) = msg->val_(i); }
116+
);
117+
}
118+
119+
template <template <typename Arg> class Op>
120+
static void reduce(
121+
DataT& dest, size_t start_idx, RabenseifnerMsg<Scalar, DataT>* msg) {
122+
Kokkos::parallel_for(
123+
"Rabenseifner::reduce", msg->val_.extent(0), KOKKOS_LAMBDA(const int i) {
124+
Op<Scalar>()(dest(start_idx + i), msg->val_(i));
125+
}
126+
);
127+
}
128+
129+
static void invoke() { }
130+
131+
static bool empty(const DataT& payload) {
132+
return payload.extent(0) == 0;
133+
}
134+
};
135+
136+
#endif // MAGISTRATE_KOKKOS_ENABLED
137+
138+
struct StateBase {
139+
size_t size_ = {};
140+
141+
bool finished_adjustment_part_ = false;
142+
143+
int32_t mask_ = 1;
144+
int32_t step_ = 0;
145+
bool initialized_ = false;
146+
bool completed_ = false;
147+
148+
// Scatter
149+
int32_t scatter_mask_ = 1;
150+
int32_t scatter_step_ = 0;
151+
int32_t scatter_num_recv_ = 0;
152+
std::vector<bool> scatter_steps_recv_ = {};
153+
std::vector<bool> scatter_steps_reduced_ = {};
154+
155+
bool finished_scatter_part_ = false;
156+
157+
// Gather
158+
int32_t gather_step_ = 0;
159+
int32_t gather_mask_ = 1;
160+
int32_t gather_num_recv_ = 0;
161+
std::vector<bool> gather_steps_recv_ = {};
162+
std::vector<bool> gather_steps_reduced_ = {};
163+
164+
std::vector<uint32_t> r_index_ = {};
165+
std::vector<uint32_t> r_count_ = {};
166+
std::vector<uint32_t> s_index_ = {};
167+
std::vector<uint32_t> s_count_ = {};
168+
};
169+
170+
template <typename Scalar, typename DataT>
171+
struct State : StateBase {
172+
std::vector<Scalar> val_ = {};
173+
174+
MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>> left_adjust_message_ = nullptr;
175+
MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>> right_adjust_message_ = nullptr;
176+
std::vector<MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>>> scatter_messages_ = {};
177+
std::vector<MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>>> gather_messages_ = {};
178+
};
179+
180+
#if MAGISTRATE_KOKKOS_ENABLED
181+
template <typename Scalar>
182+
struct State<Scalar, Kokkos::View<Scalar*, Kokkos::HostSpace>> : StateBase {
183+
using DataT = Kokkos::View<Scalar*, Kokkos::HostSpace>;
184+
185+
Kokkos::View<Scalar*, Kokkos::HostSpace> val_ = {};
186+
187+
MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>> left_adjust_message_ = nullptr;
188+
MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>> right_adjust_message_ = nullptr;
189+
std::vector<MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>>> scatter_messages_ = {};
190+
std::vector<MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>>> gather_messages_ = {};
191+
};
192+
#endif //MAGISTRATE_KOKKOS_ENABLED
193+
194+
} // namespace vt::collective::reduce::allreduce
195+
#endif /*INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_HELPERS_H*/

0 commit comments

Comments
 (0)