Skip to content

Commit 5ce7cd9

Browse files
committed
#2240: Update Rabenseifner to use ID for each allreduce and update tests
1 parent 54d4bef commit 5ce7cd9

11 files changed

+409
-355
lines changed

src/vt/collective/reduce/allreduce/data_handler.h

+6-6
Original file line numberDiff line numberDiff line change
@@ -57,21 +57,22 @@ template <typename DataType, typename Enable = void>
5757
// Primary template, selected when no specialization matches DataType:
// it exposes a void Scalar and a size() that always returns 0.
class DataHandler {
5858
public:
5959
using Scalar = void;
60+
static size_t size(void) { return 0; }
6061
};
6162

62-
template <typename Scalar>
63-
class DataHandler<Scalar, typename std::enable_if<std::is_arithmetic<Scalar>::value>::type> {
63+
template <typename ScalarType>
64+
class DataHandler<ScalarType, typename std::enable_if<std::is_arithmetic<ScalarType>::value>::type> {
6465
public:
65-
using ScalarType = Scalar;
66+
using Scalar = ScalarType;
6667

6768
static std::vector<ScalarType> toVec(const ScalarType& data) { return std::vector<ScalarType>{data}; }
6869
static ScalarType fromVec(const std::vector<ScalarType>& data) { return data[0]; }
69-
static ScalarType fromMemory(ScalarType* data, size_t count) {
70+
static ScalarType fromMemory(ScalarType* data, size_t) {
7071
return *data;
7172
}
7273

7374
// static const ScalarType* data(const ScalarType& data) { return &data; }
74-
// static size_t size(const ScalarType&) { return 1; }
75+
static size_t size(const ScalarType&) { return 1; }
7576
// static ScalarType& at(ScalarType& data, size_t) { return data; }
7677
// static void set(ScalarType& data, size_t, const ScalarType& value) { data = value; }
7778
// static ScalarType split(ScalarType&, size_t, size_t) { return ScalarType{}; }
@@ -80,7 +81,6 @@ class DataHandler<Scalar, typename std::enable_if<std::is_arithmetic<Scalar>::va
8081
template <typename T>
8182
class DataHandler<std::vector<T>> {
8283
public:
83-
using UnderlyingType = std::vector<T>;
8484
using Scalar = T;
8585

8686
static const std::vector<T>& toVec(const std::vector<T>& data) { return data; }

src/vt/collective/reduce/allreduce/rabenseifner.h

+63-75
Original file line numberDiff line numberDiff line change
@@ -56,36 +56,6 @@
5656

5757
namespace vt::collective::reduce::allreduce {
5858

59-
template <typename DataT>
60-
struct AllreduceRbnMsg
61-
: SerializeIfNeeded<vt::Message, AllreduceRbnMsg<DataT>, DataT> {
62-
using MessageParentType =
63-
SerializeIfNeeded<::vt::Message, AllreduceRbnMsg<DataT>, DataT>;
64-
65-
AllreduceRbnMsg() = default;
66-
AllreduceRbnMsg(AllreduceRbnMsg const&) = default;
67-
AllreduceRbnMsg(AllreduceRbnMsg&&) = default;
68-
69-
AllreduceRbnMsg(DataT&& in_val, int step = 0)
70-
: MessageParentType(),
71-
val_(std::forward<DataT>(in_val)),
72-
step_(step) { }
73-
AllreduceRbnMsg(DataT const& in_val, int step = 0)
74-
: MessageParentType(),
75-
val_(in_val),
76-
step_(step) { }
77-
78-
template <typename SerializeT>
79-
void serialize(SerializeT& s) {
80-
MessageParentType::serialize(s);
81-
s | val_;
82-
s | step_;
83-
}
84-
85-
DataT val_ = {};
86-
int32_t step_ = {};
87-
};
88-
8959
template <typename Scalar>
9060
struct AllreduceRbnRawMsg
9161
: Message {
@@ -102,10 +72,11 @@ struct AllreduceRbnRawMsg
10272
}
10373
}
10474

105-
AllreduceRbnRawMsg(Scalar* in_val, size_t size, int step = 0)
75+
AllreduceRbnRawMsg(Scalar* in_val, size_t size, size_t id, int step = 0)
10676
: MessageParentType(),
10777
val_(in_val),
10878
size_(size),
79+
id_(id),
10980
step_(step) { }
11081

11182
template <typename SerializeT>
@@ -121,11 +92,13 @@ struct AllreduceRbnRawMsg
12192

12293
checkpoint::dispatch::serializeArray(s, val_, size_);
12394

95+
s | id_;
12496
s | step_;
12597
}
12698

12799
Scalar* val_ = {};
128100
size_t size_ = {};
101+
size_t id_ = {};
129102
int32_t step_ = {};
130103
bool owning_ = false;
131104
};
@@ -169,27 +142,30 @@ struct Rabenseifner {
169142
* \param id Identifier of the allreduce this call applies to.
* \param args Additional arguments for initializing the data value.
170143
*/
171144
template <typename ...Args>
172-
void initialize(Args&&... args);
145+
void initialize(size_t id, Args&&... args);
146+
147+
void initializeState(size_t id);
148+
size_t generateNewId() { return id_++; }
173149

174150
/**
175151
* \brief Execute the final handler callback with the reduced result.
176152
*/
177-
void executeFinalHan();
153+
void executeFinalHan(size_t id);
178154

179155
/**
180156
* \brief Perform the allreduce operation.
181157
*
182158
* This function starts the allreduce operation, adjusting for non-power-of-two process counts if necessary.
183159
*/
184-
void allreduce();
160+
void allreduce(size_t id);
185161

186162
/**
187163
* \brief Adjust the process count to the nearest power-of-two.
188164
*
189165
* This function performs additional steps to handle non-power-of-two process counts, ensuring that the
190166
* main scatter-reduce and gather-allgather phases can proceed with a power-of-two number of processes.
191167
*/
192-
void adjustForPowerOfTwo();
168+
void adjustForPowerOfTwo(size_t id);
193169

194170
/**
195171
* \brief Handler for adjusting the right half of the process group.
@@ -223,35 +199,35 @@ struct Rabenseifner {
223199
*
224200
* \return True if all scatter messages have been received, false otherwise.
225201
*/
226-
bool scatterAllMessagesReceived();
202+
bool scatterAllMessagesReceived(size_t id);
227203

228204
/**
229205
* \brief Check if the scatter phase is complete.
230206
*
231207
* \return True if the scatter phase is complete, false otherwise.
232208
*/
233-
bool scatterIsDone();
209+
bool scatterIsDone(size_t id);
234210

235211
/**
236212
* \brief Check if the scatter phase is ready to proceed.
237213
*
238214
* \return True if the scatter phase is ready to proceed, false otherwise.
239215
*/
240-
bool scatterIsReady();
216+
bool scatterIsReady(size_t id);
241217

242218
/**
243219
* \brief Try to reduce the received scatter messages.
244220
*
245221
* \param id Identifier of the allreduce this call applies to.
* \param step The current step in the scatter phase.
246222
*/
247-
void scatterTryReduce(int32_t step);
223+
void scatterTryReduce(size_t id, int32_t step);
248224

249225
/**
250226
* \brief Perform the scatter-reduce iteration.
251227
*
252228
* This function sends data to the appropriate partner process and proceeds to the next step in the scatter phase.
253229
*/
254-
void scatterReduceIter();
230+
void scatterReduceIter(size_t id);
255231

256232
/**
257233
* \brief Handler for receiving scatter-reduce messages.
@@ -267,35 +243,35 @@ struct Rabenseifner {
267243
*
268244
* \return True if all gather messages have been received, false otherwise.
269245
*/
270-
bool gatherAllMessagesReceived();
246+
bool gatherAllMessagesReceived(size_t id);
271247

272248
/**
273249
* \brief Check if the gather phase is complete.
274250
*
275251
* \return True if the gather phase is complete, false otherwise.
276252
*/
277-
bool gatherIsDone();
253+
bool gatherIsDone(size_t id);
278254

279255
/**
280256
* \brief Check if the gather phase is ready to proceed.
281257
*
282258
* \return True if the gather phase is ready to proceed, false otherwise.
283259
*/
284-
bool gatherIsReady();
260+
bool gatherIsReady(size_t id);
285261

286262
/**
287263
* \brief Try to reduce the received gather messages.
288264
*
289265
* \param id Identifier of the allreduce this call applies to.
* \param step The current step in the gather phase.
290266
*/
291-
void gatherTryReduce(int32_t step);
267+
void gatherTryReduce(size_t id, int32_t step);
292268

293269
/**
294270
* \brief Perform the gather iteration.
295271
*
296272
* This function sends data to the appropriate partner process and proceeds to the next step in the gather phase.
297273
*/
298-
void gatherIter();
274+
void gatherIter(size_t id);
299275

300276
/**
301277
* \brief Handler for receiving gather messages.
@@ -311,14 +287,14 @@ struct Rabenseifner {
311287
*
312288
* This function completes the allreduce operation, handling any remaining steps and invoking the final handler.
313289
*/
314-
void finalPart();
290+
void finalPart(size_t id);
315291

316292
/**
317293
* \brief Send the result to excluded nodes.
318294
*
319295
* This function handles the final step for non-power-of-two process counts, sending the reduced result to excluded nodes.
320296
*/
321-
void sendToExcludedNodes();
297+
void sendToExcludedNodes(size_t id);
322298

323299
/**
324300
* \brief Handler for receiving the final result on excluded nodes.
@@ -332,9 +308,46 @@ struct Rabenseifner {
332308
vt::objgroup::proxy::Proxy<Rabenseifner> proxy_ = {};
333309
vt::objgroup::proxy::Proxy<ObjT> parent_proxy_ = {};
334310

335-
// DataT val_ = {};
336-
std::vector<Scalar> val_;
337-
size_t size_ = {};
311+
// State for a single allreduce. Rabenseifner keeps these in states_, an
// unordered_map keyed by a size_t id, so several allreduces can each have
// their own copy of the fields below.
struct State {
312+
std::vector<Scalar> val_ = {};
313+
size_t size_ = {};
314+
315+
// Adjustment-phase progress flag and the messages held for it.
// NOTE(review): presumably the non-power-of-two pre-step; confirm against adjustForPowerOfTwo().
bool finished_adjustment_part_ = false;
316+
MsgSharedPtr<AllreduceRbnRawMsg<Scalar>> left_adjust_message_ = nullptr;
317+
MsgSharedPtr<AllreduceRbnRawMsg<Scalar>> right_adjust_message_ = nullptr;
318+
319+
int32_t mask_ = 1;
320+
int32_t step_ = 0;
321+
bool initialized_ = false;
322+
bool completed_ = false;
323+
324+
// Scatter
// Mask/step counters, received-message count, received/reduced flags and
// buffered messages for the scatter phase.
// NOTE(review): the vectors look like they are indexed by step; confirm.
325+
int32_t scatter_mask_ = 1;
326+
int32_t scatter_step_ = 0;
327+
int32_t scatter_num_recv_ = 0;
328+
std::vector<bool> scatter_steps_recv_ = {};
329+
std::vector<bool> scatter_steps_reduced_ = {};
330+
std::vector<MsgSharedPtr<AllreduceRbnRawMsg<Scalar>>> scatter_messages_ =
331+
{};
332+
bool finished_scatter_part_ = false;
333+
334+
// Gather
// Same bookkeeping as the scatter group, for the gather phase.
335+
int32_t gather_step_ = 0;
336+
int32_t gather_mask_ = 1;
337+
int32_t gather_num_recv_ = 0;
338+
std::vector<bool> gather_steps_recv_ = {};
339+
std::vector<bool> gather_steps_reduced_ = {};
340+
std::vector<MsgSharedPtr<AllreduceRbnRawMsg<Scalar>>> gather_messages_ =
341+
{};
342+
343+
// NOTE(review): presumably receive (r_) and send (s_) offsets and counts
// per step; confirm against the implementation file.
std::vector<uint32_t> r_index_ = {};
344+
std::vector<uint32_t> r_count_ = {};
345+
std::vector<uint32_t> s_index_ = {};
346+
std::vector<uint32_t> s_count_ = {};
347+
};
348+
349+
size_t id_ = 0;
350+
std::unordered_map<size_t, State> states_ = {};
338351
NodeType num_nodes_ = {};
339352
NodeType this_node_ = {};
340353

@@ -343,33 +356,8 @@ struct Rabenseifner {
343356
int32_t nprocs_pof2_ = {};
344357
int32_t nprocs_rem_ = {};
345358

346-
std::vector<uint32_t> r_index_ = {};
347-
std::vector<uint32_t> r_count_ = {};
348-
std::vector<uint32_t> s_index_ = {};
349-
std::vector<uint32_t> s_count_ = {};
350-
351359
NodeType vrt_node_ = {};
352360
bool is_part_of_adjustment_group_ = false;
353-
bool finished_adjustment_part_ = false;
354-
355-
bool completed_ = false;
356-
357-
// Scatter
358-
int32_t scatter_mask_ = 1;
359-
int32_t scatter_step_ = 0;
360-
int32_t scatter_num_recv_ = 0;
361-
std::vector<bool> scatter_steps_recv_ = {};
362-
std::vector<bool> scatter_steps_reduced_ = {};
363-
std::vector<MsgSharedPtr<AllreduceRbnRawMsg<Scalar>>> scatter_messages_ = {};
364-
bool finished_scatter_part_ = false;
365-
366-
// Gather
367-
int32_t gather_step_ = 0;
368-
int32_t gather_mask_ = 1;
369-
int32_t gather_num_recv_ = 0;
370-
std::vector<bool> gather_steps_recv_ = {};
371-
std::vector<bool> gather_steps_reduced_ = {};
372-
std::vector<MsgSharedPtr<AllreduceRbnRawMsg<Scalar>>> gather_messages_ = {};
373361
};
374362

375363
} // namespace vt::collective::reduce::allreduce

0 commit comments

Comments
 (0)