Skip to content

Commit 31831c9

Browse files
committed
#2240: Working RecursiveDoubling with multiple allreduce in flight
1 parent 6332510 commit 31831c9

File tree

3 files changed

+247
-188
lines changed

3 files changed

+247
-188
lines changed

src/vt/collective/reduce/allreduce/recursive_doubling.h

+48-68
Original file line numberDiff line numberDiff line change
@@ -58,36 +58,6 @@
5858
namespace vt::collective::reduce::allreduce {
5959

6060
template <typename DataT>
61-
struct AllreduceDblMsg
62-
: SerializeIfNeeded<vt::Message, AllreduceDblMsg<DataT>, DataT> {
63-
using MessageParentType =
64-
SerializeIfNeeded<::vt::Message, AllreduceDblMsg<DataT>, DataT>;
65-
66-
AllreduceDblMsg() = default;
67-
AllreduceDblMsg(AllreduceDblMsg const&) = default;
68-
AllreduceDblMsg(AllreduceDblMsg&&) = default;
69-
70-
AllreduceDblMsg(DataT&& in_val, int step = 0)
71-
: MessageParentType(),
72-
val_(std::forward<DataT>(in_val)),
73-
step_(step) { }
74-
AllreduceDblMsg(DataT const& in_val, int step = 0)
75-
: MessageParentType(),
76-
val_(in_val),
77-
step_(step) { }
78-
79-
template <typename SerializeT>
80-
void serialize(SerializeT& s) {
81-
MessageParentType::serialize(s);
82-
s | val_;
83-
s | step_;
84-
}
85-
86-
DataT val_ = {};
87-
int32_t step_ = {};
88-
};
89-
90-
template <typename Scalar>
9161
struct AllreduceDblRawMsg
9262
: Message {
9363
using MessageParentType = vt::Message;
@@ -99,34 +69,32 @@ struct AllreduceDblRawMsg
9969
AllreduceDblRawMsg(AllreduceDblRawMsg&&) = default;
10070
~AllreduceDblRawMsg() {
10171
if (owning_) {
102-
delete[] val_;
72+
delete val_;
10373
}
10474
}
10575

106-
AllreduceDblRawMsg(std::vector<Scalar>& in_val, int step = 0)
76+
AllreduceDblRawMsg(DataT const& in_val, size_t id, int step = 0)
10777
: MessageParentType(),
108-
val_(in_val.data()),
109-
size_(in_val.size()),
78+
val_(&in_val),
79+
id_(id),
11080
step_(step) { }
11181

11282
template <typename SerializeT>
11383
void serialize(SerializeT& s) {
11484
MessageParentType::serialize(s);
11585

116-
s | size_;
117-
11886
if (s.isUnpacking()) {
11987
owning_ = true;
120-
val_ = new Scalar[size_];
88+
val_ = new DataT();
12189
}
12290

123-
checkpoint::dispatch::serializeArray(s, val_, size_);
124-
91+
s | *val_;
92+
s | id_;
12593
s | step_;
12694
}
12795

128-
Scalar* val_ = {};
129-
size_t size_ = {};
96+
const DataT* val_ = {};
97+
size_t id_ = {};
13098
int32_t step_ = {};
13199
bool owning_ = false;
132100
};
@@ -158,103 +126,126 @@ struct RecursiveDoubling {
158126
* \param num_nodes The number of nodes.
159127
* \param args Additional arguments for data initialization.
160128
*/
129+
template <typename... Args>
161130
RecursiveDoubling(
162131
vt::objgroup::proxy::Proxy<ObjT> parentProxy, NodeType num_nodes,
163-
const DataT& data);
132+
Args&&... data);
164133

165134
/**
166135
* \brief Start the allreduce operation.
167136
*/
168-
void allreduce();
137+
void allreduce(size_t id);
169138

170139
/**
171140
* \brief Initialize the RecursiveDoubling object.
172141
*
173142
* \param args Additional arguments for data initialization.
174143
*/
175-
void initialize(const DataT& data);
144+
template <typename... Args>
145+
void initialize(size_t id, Args&&... data);
146+
void initializeState(size_t id);
147+
148+
size_t generateNewId() { return id_++; }
176149

177150
/**
178151
* \brief Adjust for power of two nodes.
179152
*/
180-
void adjustForPowerOfTwo();
153+
void adjustForPowerOfTwo(size_t id);
181154

182155
/**
183156
* \brief Handler for adjusting for power of two nodes.
184157
*
185158
* \param msg Pointer to the message.
186159
*/
187-
void adjustForPowerOfTwoHandler(AllreduceDblRawMsg<Scalar>* msg);
160+
void adjustForPowerOfTwoHandler(AllreduceDblRawMsg<DataT>* msg);
188161

189162
/**
190163
* \brief Check if the allreduce operation is done.
191164
*
192165
* \return True if the operation is done, otherwise false.
193166
*/
194-
bool done();
167+
bool isDone(size_t id);
195168

196169
/**
197170
* \brief Check if the current state is valid for allreduce.
198171
*
199172
* \return True if the state is valid, otherwise false.
200173
*/
201-
bool isValid();
174+
bool isValid(size_t id);
202175

203176
/**
204177
* \brief Check if all messages are received for the current step.
205178
*
206179
* \return True if all messages are received, otherwise false.
207180
*/
208-
bool allMessagesReceived();
181+
bool allMessagesReceived(size_t id);
209182

210183
/**
211184
* \brief Check if the object is ready for the next step of allreduce.
212185
*
213186
* \return True if ready, otherwise false.
214187
*/
215-
bool isReady();
188+
bool isReady(size_t id);
216189

217190
/**
218191
* \brief Perform the next step of the allreduce operation.
219192
*/
220-
void reduceIter();
193+
void reduceIter(size_t id);
221194

222195
/**
223196
* \brief Try to reduce the message at the specified step.
224197
*
225198
* \param step The step at which to try reduction.
226199
*/
227-
void tryReduce(int32_t step);
200+
void tryReduce(size_t id, int32_t step);
228201

229202
/**
230203
* \brief Handler for the reduce iteration.
231204
*
232205
* \param msg Pointer to the message.
233206
*/
234-
void reduceIterHandler(AllreduceDblRawMsg<Scalar>* msg);
207+
void reduceIterHandler(AllreduceDblRawMsg<DataT>* msg);
235208

236209
/**
237210
* \brief Send data to excluded nodes for finalization.
238211
*/
239-
void sendToExcludedNodes();
212+
void sendToExcludedNodes(size_t id);
240213

241214
/**
242215
* \brief Handler for sending data to excluded nodes.
243216
*
244217
* \param msg Pointer to the message.
245218
*/
246-
void sendToExcludedNodesHandler(AllreduceDblRawMsg<Scalar>* msg);
219+
void sendToExcludedNodesHandler(AllreduceDblRawMsg<DataT>* msg);
247220

248221
/**
249222
* \brief Perform the final part of the allreduce operation.
250223
*/
251-
void finalPart();
224+
void finalPart(size_t id);
252225

253226
vt::objgroup::proxy::Proxy<RecursiveDoubling> proxy_ = {};
254227
vt::objgroup::proxy::Proxy<ObjT> parent_proxy_ = {};
255228

256229
// DataT val_ = {};
257-
std::vector<Scalar> val_;
230+
231+
struct State{
232+
DataT val_ = {};
233+
bool finished_adjustment_part_ = false;
234+
MsgSharedPtr<AllreduceDblRawMsg<DataT>> adjust_message_ = nullptr;
235+
236+
int32_t mask_ = 1;
237+
int32_t step_ = 0;
238+
bool initialized_ = false;
239+
bool completed_ = false;
240+
241+
std::vector<bool> steps_recv_ = {};
242+
std::vector<bool> steps_reduced_ = {};
243+
std::vector<MsgSharedPtr<AllreduceDblRawMsg<DataT>>> messages_ = {};
244+
};
245+
246+
size_t id_ = 0;
247+
std::unordered_map<size_t, State> states_ = {};
248+
258249
NodeType num_nodes_ = {};
259250
NodeType this_node_ = {};
260251

@@ -265,17 +256,6 @@ struct RecursiveDoubling {
265256

266257
NodeType vrt_node_ = {};
267258
bool is_part_of_adjustment_group_ = false;
268-
bool finished_adjustment_part_ = false;
269-
270-
int32_t mask_ = 1;
271-
int32_t step_ = 0;
272-
273-
bool completed_ = false;
274-
275-
std::vector<bool> steps_recv_ = {};
276-
std::vector<bool> steps_reduced_ = {};
277-
278-
std::vector<MsgSharedPtr<AllreduceDblRawMsg<Scalar>>> messages_ = {};
279259
};
280260

281261
} // namespace vt::collective::reduce::allreduce

0 commit comments

Comments
 (0)