Skip to content

Commit 6332510

Browse files
committed
#2240: Working allreduce perf test with Kokkos
1 parent fdecfb0 commit 6332510

File tree

10 files changed

+468
-162
lines changed

10 files changed

+468
-162
lines changed

src/vt/collective/reduce/allreduce/data_handler.h

+66-23
Original file line numberDiff line numberDiff line change
@@ -45,38 +45,60 @@
4545
#if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_DATA_HANDLER_H
4646
#define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_DATA_HANDLER_H
4747

48-
namespace vt::collective::reduce::allreduce {
4948
#include <cstddef>
#include <cstring>
#include <type_traits>
#include <vector>
5049

5150
#ifdef VT_KOKKOS_ENABLED
5251
#include <Kokkos_Core.hpp>
5352
#endif
5453

55-
template <typename Container>
54+
namespace vt::collective::reduce::allreduce {
55+
56+
/**
 * \brief Primary template: fallback for payload types without a dedicated
 * handler.
 *
 * It only exposes \c Scalar as \c void; none of the conversion helpers
 * (toVec / fromVec / fromMemory) are available for such types.
 *
 * \tparam DataType payload type being described
 * \tparam Enable SFINAE hook used by the partial specializations
 */
template <typename DataType, typename Enable = void>
class DataHandler {
public:
  using Scalar = void;
};

/**
 * \brief Handler for a payload that is a single arithmetic value.
 *
 * The template parameter is deliberately *not* called \c Scalar: a class
 * member may not redeclare a template parameter name, and this specialization
 * has to publish a \c Scalar alias just like the primary template and the
 * container specializations do, so that generic code can write
 * \c DataHandler<X>::Scalar for any payload type.
 *
 * \tparam ScalarT arithmetic type (selected via \c std::is_arithmetic)
 */
template <typename ScalarT>
class DataHandler<
  ScalarT, typename std::enable_if<std::is_arithmetic<ScalarT>::value>::type
> {
public:
  /// Element type, under the same name the other handlers use.
  using Scalar = ScalarT;
  /// Kept for existing users of the original alias.
  using ScalarType = ScalarT;

  /**
   * \brief Wrap the value into a one-element vector.
   *
   * \param data value to wrap
   * \return vector holding exactly \c data
   */
  static std::vector<ScalarType> toVec(const ScalarType& data) {
    return std::vector<ScalarType>{data};
  }

  /**
   * \brief Extract the value from its vector representation.
   *
   * \param data vector whose first element is the value; must not be empty
   * (element 0 is read unconditionally)
   * \return copy of \c data[0]
   */
  static ScalarType fromVec(const std::vector<ScalarType>& data) {
    return data[0];
  }

  /**
   * \brief Read the value from raw memory.
   *
   * \param data pointer to the value; must be dereferenceable
   * \param count ignored for scalars; present so the signature matches the
   * container handlers
   * \return copy of \c *data
   */
  static ScalarType fromMemory(
    ScalarType* data, [[maybe_unused]] size_t count
  ) {
    return *data;
  }

  // static const ScalarType* data(const ScalarType& data) { return &data; }
  // static size_t size(const ScalarType&) { return 1; }
  // static ScalarType& at(ScalarType& data, size_t) { return data; }
  // static void set(ScalarType& data, size_t, const ScalarType& value) { data = value; }
  // static ScalarType split(ScalarType&, size_t, size_t) { return ScalarType{}; }
};
6579

6680
template <typename T>
6781
class DataHandler<std::vector<T>> {
6882
public:
6983
using UnderlyingType = std::vector<T>;
7084
using Scalar = T;
71-
static size_t size(const std::vector<T>& data) { return data.size(); }
72-
static T at(const std::vector<T>& data, size_t idx) { return data[idx]; }
73-
static T& at(std::vector<T>& data, size_t idx) { return data[idx]; }
74-
static void set(std::vector<T>& data, size_t idx, const T& value) {
75-
data[idx] = value;
76-
}
77-
static std::vector<T> split(std::vector<T>& data, size_t start, size_t end) {
78-
return std::vector<T>{data.begin() + start, data.begin() + end};
85+
86+
static const std::vector<T>& toVec(const std::vector<T>& data) { return data; }
87+
static std::vector<T> fromVec(const std::vector<T>& data) { return data; }
88+
static std::vector<T> fromMemory(T* data, size_t count) {
89+
return std::vector<T>(data, data + count);
7990
}
91+
92+
// static const T* data(const std::vector<T>& data) {return data.data(); }
93+
static size_t size(const std::vector<T>& data) { return data.size(); }
94+
// static T at(const std::vector<T>& data, size_t idx) { return data[idx]; }
95+
// static T& at(std::vector<T>& data, size_t idx) { return data[idx]; }
96+
// static void set(std::vector<T>& data, size_t idx, const T& value) {
97+
// data[idx] = value;
98+
// }
99+
// static std::vector<T> split(std::vector<T>& data, size_t start, size_t end) {
100+
// return std::vector<T>{data.begin() + start, data.begin() + end};
101+
// }
80102
};
81103

82104
#if KOKKOS_ENABLED_CHECKPOINT
@@ -88,19 +110,40 @@ class DataHandler<Kokkos::View<T*, Kokkos::HostSpace, Props...>> {
88110
public:
89111
using Scalar = T;
90112

91-
static size_t size(const ViewType& data) { return data.extent(0); }
113+
static std::vector<T> toVec(const ViewType& data) {
114+
std::vector<T> vec;
115+
vec.resize(data.extent(0));
116+
std::memcpy(vec.data(), data.data(), data.extent(0) * sizeof(T));
117+
return vec;
118+
}
92119

93-
static T at(const ViewType& data, size_t idx) { return data(idx); }
120+
/**
 * \brief Build a view from a raw pointer and an element count.
 *
 * Both arguments are handed straight to the \c ViewType constructor.
 * NOTE(review): presumably this yields a non-owning view over \c data, so
 * the memory must outlive the returned view -- confirm against the Kokkos
 * pointer constructor.
 *
 * \param data pointer to the first element
 * \param size number of elements at \c data
 * \return view constructed from \c data and \c size
 */
static ViewType fromMemory(T* data, size_t size) {
  ViewType wrapped(data, size);
  return wrapped;
}
94123

95-
static T& at(ViewType& data, size_t idx) { return data(idx); }
124+
/**
 * \brief Build a view holding a copy of a vector's elements.
 *
 * Elements are assigned as \c T. The previous implementation routed every
 * value through \c float, which silently lost precision (or range) for any
 * other element type such as \c double or 64-bit integers.
 *
 * \param data elements to copy
 * \return newly allocated view with \c data.size() elements equal to \c data
 */
static ViewType fromVec(const std::vector<T>& data) {
  ViewType view("", data.size());
  // KOKKOS_LAMBDA captures by value, so the kernel works on its own copy of
  // `data` and does not depend on the caller's vector staying alive.
  Kokkos::parallel_for(
    "InitView", view.extent(0),
    KOKKOS_LAMBDA(const int i) { view(i) = data[i]; });

  return view;
}
100132

101-
static ViewType split(ViewType& data, size_t start, size_t end) {
102-
return Kokkos::subview(data, std::make_pair(start, end));
103-
}
133+
// static const T* data(const ViewType& data) {return data.data(); }
134+
/**
 * \brief Number of elements in a rank-1 view.
 *
 * \param data view to query
 * \return extent of dimension 0
 */
static size_t size(const ViewType& data) {
  const size_t num_elems = data.extent(0);
  return num_elems;
}
135+
136+
// static T at(const ViewType& data, size_t idx) { return data(idx); }
137+
138+
// static T& at(ViewType& data, size_t idx) { return data(idx); }
139+
140+
// static void set(ViewType& data, size_t idx, const T& value) {
141+
// data(idx) = value;
142+
// }
143+
144+
// static ViewType split(ViewType& data, size_t start, size_t end) {
145+
// return Kokkos::subview(data, std::make_pair(start, end));
146+
// }
104147
};
105148

106149
#endif // KOKKOS_ENABLED_CHECKPOINT

src/vt/collective/reduce/allreduce/rabenseifner.h

+61-15
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,50 @@ struct AllreduceRbnMsg
8686
int32_t step_ = {};
8787
};
8888

89+
template <typename Scalar>
90+
struct AllreduceRbnRawMsg
91+
: Message {
92+
using MessageParentType = vt::Message;
93+
vt_msg_serialize_required();
94+
95+
96+
AllreduceRbnRawMsg() = default;
97+
AllreduceRbnRawMsg(AllreduceRbnRawMsg const&) = default;
98+
AllreduceRbnRawMsg(AllreduceRbnRawMsg&&) = default;
99+
~AllreduceRbnRawMsg() {
100+
if (owning_) {
101+
delete[] val_;
102+
}
103+
}
104+
105+
AllreduceRbnRawMsg(Scalar* in_val, size_t size, int step = 0)
106+
: MessageParentType(),
107+
val_(in_val),
108+
size_(size),
109+
step_(step) { }
110+
111+
template <typename SerializeT>
112+
void serialize(SerializeT& s) {
113+
MessageParentType::serialize(s);
114+
115+
s | size_;
116+
117+
if (s.isUnpacking()) {
118+
owning_ = true;
119+
val_ = new Scalar[size_];
120+
}
121+
122+
checkpoint::dispatch::serializeArray(s, val_, size_);
123+
124+
s | step_;
125+
}
126+
127+
Scalar* val_ = {};
128+
size_t size_ = {};
129+
int32_t step_ = {};
130+
bool owning_ = false;
131+
};
132+
89133
/**
90134
* \struct Rabenseifner
91135
* \brief Class implementing Rabenseifner's allreduce algorithm.
@@ -103,6 +147,7 @@ template <
103147
>
104148
struct Rabenseifner {
105149
using DataType = DataHandler<DataT>;
150+
using Scalar = typename DataType::Scalar;
106151

107152
/**
108153
* \brief Constructor for Rabenseifner's allreduce algorithm.
@@ -111,7 +156,7 @@ struct Rabenseifner {
111156
* \param num_nodes Total number of nodes involved in the allreduce operation.
112157
* \param args Additional arguments for initializing the data value.
113158
*/
114-
template <typename... Args>
159+
template <typename ...Args>
115160
Rabenseifner(
116161
vt::objgroup::proxy::Proxy<ObjT> parentProxy, NodeType num_nodes,
117162
Args&&... args);
@@ -123,7 +168,7 @@ struct Rabenseifner {
123168
*
124169
* \param args Additional arguments for initializing the data value.
125170
*/
126-
template <typename... Args>
171+
template <typename ...Args>
127172
void initialize(Args&&... args);
128173

129174
/**
@@ -153,7 +198,7 @@ struct Rabenseifner {
153198
*
154199
* \param msg Message containing the data from the partner process.
155200
*/
156-
void adjustForPowerOfTwoRightHalf(AllreduceRbnMsg<DataT>* msg);
201+
void adjustForPowerOfTwoRightHalf(AllreduceRbnRawMsg<Scalar>* msg);
157202

158203
/**
159204
* \brief Handler for adjusting the left half of the process group.
@@ -162,7 +207,7 @@ struct Rabenseifner {
162207
*
163208
* \param msg Message containing the data from the partner process.
164209
*/
165-
void adjustForPowerOfTwoLeftHalf(AllreduceRbnMsg<DataT>* msg);
210+
void adjustForPowerOfTwoLeftHalf(AllreduceRbnRawMsg<Scalar>* msg);
166211

167212
/**
168213
* \brief Final adjustment step for non-power-of-two process counts.
@@ -171,7 +216,7 @@ struct Rabenseifner {
171216
*
172217
* \param msg Message containing the data from the partner process.
173218
*/
174-
void adjustForPowerOfTwoFinalPart(AllreduceRbnMsg<DataT>* msg);
219+
void adjustForPowerOfTwoFinalPart(AllreduceRbnRawMsg<Scalar>* msg);
175220

176221
/**
177222
* \brief Check if all scatter messages have been received.
@@ -215,7 +260,7 @@ struct Rabenseifner {
215260
*
216261
* \param msg Message containing the data from the partner process.
217262
*/
218-
void scatterReduceIterHandler(AllreduceRbnMsg<DataT>* msg);
263+
void scatterReduceIterHandler(AllreduceRbnRawMsg<Scalar>* msg);
219264

220265
/**
221266
* \brief Check if all gather messages have been received.
@@ -259,7 +304,7 @@ struct Rabenseifner {
259304
*
260305
* \param msg Message containing the data from the partner process.
261306
*/
262-
void gatherIterHandler(AllreduceRbnMsg<DataT>* msg);
307+
void gatherIterHandler(AllreduceRbnRawMsg<Scalar>* msg);
263308

264309
/**
265310
* \brief Perform the final part of the allreduce operation.
@@ -282,12 +327,13 @@ struct Rabenseifner {
282327
*
283328
* \param msg Message containing the final result.
284329
*/
285-
void sendToExcludedNodesHandler(AllreduceRbnMsg<DataT>* msg);
330+
void sendToExcludedNodesHandler(AllreduceRbnRawMsg<Scalar>* msg);
286331

287332
vt::objgroup::proxy::Proxy<Rabenseifner> proxy_ = {};
288333
vt::objgroup::proxy::Proxy<ObjT> parent_proxy_ = {};
289334

290-
DataT val_ = {};
335+
// DataT val_ = {};
336+
std::vector<Scalar> val_;
291337
size_t size_ = {};
292338
NodeType num_nodes_ = {};
293339
NodeType this_node_ = {};
@@ -297,10 +343,10 @@ struct Rabenseifner {
297343
int32_t nprocs_pof2_ = {};
298344
int32_t nprocs_rem_ = {};
299345

300-
std::vector<int32_t> r_index_ = {};
301-
std::vector<int32_t> r_count_ = {};
302-
std::vector<int32_t> s_index_ = {};
303-
std::vector<int32_t> s_count_ = {};
346+
std::vector<uint32_t> r_index_ = {};
347+
std::vector<uint32_t> r_count_ = {};
348+
std::vector<uint32_t> s_index_ = {};
349+
std::vector<uint32_t> s_count_ = {};
304350

305351
NodeType vrt_node_ = {};
306352
bool is_part_of_adjustment_group_ = false;
@@ -314,7 +360,7 @@ struct Rabenseifner {
314360
int32_t scatter_num_recv_ = 0;
315361
std::vector<bool> scatter_steps_recv_ = {};
316362
std::vector<bool> scatter_steps_reduced_ = {};
317-
std::vector<MsgSharedPtr<AllreduceRbnMsg<DataT>>> scatter_messages_ = {};
363+
std::vector<MsgSharedPtr<AllreduceRbnRawMsg<Scalar>>> scatter_messages_ = {};
318364
bool finished_scatter_part_ = false;
319365

320366
// Gather
@@ -323,7 +369,7 @@ struct Rabenseifner {
323369
int32_t gather_num_recv_ = 0;
324370
std::vector<bool> gather_steps_recv_ = {};
325371
std::vector<bool> gather_steps_reduced_ = {};
326-
std::vector<MsgSharedPtr<AllreduceRbnMsg<DataT>>> gather_messages_ = {};
372+
std::vector<MsgSharedPtr<AllreduceRbnRawMsg<Scalar>>> gather_messages_ = {};
327373
};
328374

329375
} // namespace vt::collective::reduce::allreduce

0 commit comments

Comments
 (0)