56
56
57
57
namespace vt ::collective::reduce::allreduce {
58
58
59
- template <typename DataT>
60
- struct AllreduceRbnMsg
61
- : SerializeIfNeeded<vt::Message, AllreduceRbnMsg<DataT>, DataT> {
62
- using MessageParentType =
63
- SerializeIfNeeded<::vt::Message, AllreduceRbnMsg<DataT>, DataT>;
64
-
65
- AllreduceRbnMsg () = default ;
66
- AllreduceRbnMsg (AllreduceRbnMsg const &) = default ;
67
- AllreduceRbnMsg (AllreduceRbnMsg&&) = default ;
68
-
69
- AllreduceRbnMsg (DataT&& in_val, int step = 0 )
70
- : MessageParentType(),
71
- val_ (std::forward<DataT>(in_val)),
72
- step_(step) { }
73
- AllreduceRbnMsg (DataT const & in_val, int step = 0 )
74
- : MessageParentType(),
75
- val_(in_val),
76
- step_(step) { }
77
-
78
- template <typename SerializeT>
79
- void serialize (SerializeT& s) {
80
- MessageParentType::serialize (s);
81
- s | val_;
82
- s | step_;
83
- }
84
-
85
- DataT val_ = {};
86
- int32_t step_ = {};
87
- };
88
-
89
59
template <typename Scalar>
90
60
struct AllreduceRbnRawMsg
91
61
: Message {
@@ -102,10 +72,11 @@ struct AllreduceRbnRawMsg
102
72
}
103
73
}
104
74
105
- AllreduceRbnRawMsg (Scalar* in_val, size_t size, int step = 0 )
75
+ AllreduceRbnRawMsg (Scalar* in_val, size_t size, size_t id, int step = 0 )
106
76
: MessageParentType(),
107
77
val_ (in_val),
108
78
size_(size),
79
+ id_(id),
109
80
step_(step) { }
110
81
111
82
template <typename SerializeT>
@@ -121,11 +92,13 @@ struct AllreduceRbnRawMsg
121
92
122
93
checkpoint::dispatch::serializeArray (s, val_, size_);
123
94
95
+ s | id_;
124
96
s | step_;
125
97
}
126
98
127
99
Scalar* val_ = {};
128
100
size_t size_ = {};
101
+ size_t id_ = {};
129
102
int32_t step_ = {};
130
103
bool owning_ = false ;
131
104
};
@@ -169,27 +142,30 @@ struct Rabenseifner {
169
142
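* \param id Identifier passed alongside the data (presumably the key of the matching entry in states_; see generateNewId()).
* \param args Additional arguments for initializing the data value.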
170
143
*/
171
144
template <typename ...Args>
172
- void initialize (Args&&... args);
145
+ void initialize (size_t id, Args&&... args);
146
+
147
+ void initializeState (size_t id);
148
+ size_t generateNewId () { return id_++; } // Post-increment: hands back the current value of id_, then advances the counter by one.
173
149
174
150
/* *
175
151
* \brief Execute the final handler callback with the reduced result.
176
152
*/
177
- void executeFinalHan ();
153
+ void executeFinalHan (size_t id );
178
154
179
155
/* *
180
156
* \brief Perform the allreduce operation.
181
157
*
182
158
* This function starts the allreduce operation, adjusting for non-power-of-two process counts if necessary.
183
159
*/
184
- void allreduce ();
160
+ void allreduce (size_t id );
185
161
186
162
/* *
187
163
* \brief Adjust the process count to the nearest power-of-two.
188
164
*
189
165
* This function performs additional steps to handle non-power-of-two process counts, ensuring that the
190
166
* main scatter-reduce and gather-allgather phases can proceed with a power-of-two number of processes.
191
167
*/
192
- void adjustForPowerOfTwo ();
168
+ void adjustForPowerOfTwo (size_t id );
193
169
194
170
/* *
195
171
* \brief Handler for adjusting the right half of the process group.
@@ -223,35 +199,35 @@ struct Rabenseifner {
223
199
*
224
200
* \return True if all scatter messages have been received, false otherwise.
225
201
*/
226
- bool scatterAllMessagesReceived ();
202
+ bool scatterAllMessagesReceived (size_t id );
227
203
228
204
/* *
229
205
* \brief Check if the scatter phase is complete.
230
206
*
231
207
* \return True if the scatter phase is complete, false otherwise.
232
208
*/
233
- bool scatterIsDone ();
209
+ bool scatterIsDone (size_t id );
234
210
235
211
/* *
236
212
* \brief Check if the scatter phase is ready to proceed.
237
213
*
238
214
* \return True if the scatter phase is ready to proceed, false otherwise.
239
215
*/
240
- bool scatterIsReady ();
216
+ bool scatterIsReady (size_t id );
241
217
242
218
/* *
243
219
* \brief Try to reduce the received scatter messages.
244
220
*
245
221
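* \param id Identifier of the allreduce whose scatter state is examined (presumably the key into states_).
* \param step The current step in the scatter phase.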
246
222
*/
247
- void scatterTryReduce (int32_t step);
223
+ void scatterTryReduce (size_t id, int32_t step);
248
224
249
225
/* *
250
226
* \brief Perform the scatter-reduce iteration.
251
227
*
252
228
* This function sends data to the appropriate partner process and proceeds to the next step in the scatter phase.
253
229
*/
254
- void scatterReduceIter ();
230
+ void scatterReduceIter (size_t id );
255
231
256
232
/* *
257
233
* \brief Handler for receiving scatter-reduce messages.
@@ -267,35 +243,35 @@ struct Rabenseifner {
267
243
*
268
244
* \return True if all gather messages have been received, false otherwise.
269
245
*/
270
- bool gatherAllMessagesReceived ();
246
+ bool gatherAllMessagesReceived (size_t id );
271
247
272
248
/* *
273
249
* \brief Check if the gather phase is complete.
274
250
*
275
251
* \return True if the gather phase is complete, false otherwise.
276
252
*/
277
- bool gatherIsDone ();
253
+ bool gatherIsDone (size_t id );
278
254
279
255
/* *
280
256
* \brief Check if the gather phase is ready to proceed.
281
257
*
282
258
* \return True if the gather phase is ready to proceed, false otherwise.
283
259
*/
284
- bool gatherIsReady ();
260
+ bool gatherIsReady (size_t id );
285
261
286
262
/* *
287
263
* \brief Try to reduce the received gather messages.
288
264
*
289
265
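* \param id Identifier of the allreduce whose gather state is examined (presumably the key into states_).
* \param step The current step in the gather phase.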
290
266
*/
291
- void gatherTryReduce (int32_t step);
267
+ void gatherTryReduce (size_t id, int32_t step);
292
268
293
269
/* *
294
270
* \brief Perform the gather iteration.
295
271
*
296
272
* This function sends data to the appropriate partner process and proceeds to the next step in the gather phase.
297
273
*/
298
- void gatherIter ();
274
+ void gatherIter (size_t id );
299
275
300
276
/* *
301
277
* \brief Handler for receiving gather messages.
@@ -311,14 +287,14 @@ struct Rabenseifner {
311
287
*
312
288
* This function completes the allreduce operation, handling any remaining steps and invoking the final handler.
313
289
*/
314
- void finalPart ();
290
+ void finalPart (size_t id );
315
291
316
292
/* *
317
293
* \brief Send the result to excluded nodes.
318
294
*
319
295
* This function handles the final step for non-power-of-two process counts, sending the reduced result to excluded nodes.
320
296
*/
321
- void sendToExcludedNodes ();
297
+ void sendToExcludedNodes (size_t id );
322
298
323
299
/* *
324
300
* \brief Handler for receiving the final result on excluded nodes.
@@ -332,9 +308,46 @@ struct Rabenseifner {
332
308
vt::objgroup::proxy::Proxy<Rabenseifner> proxy_ = {};
333
309
vt::objgroup::proxy::Proxy<ObjT> parent_proxy_ = {};
334
310
335
- // DataT val_ = {};
336
- std::vector<Scalar> val_;
337
- size_t size_ = {};
311
+ struct State { // Bookkeeping bundle stored by value in states_ (keyed by a size_t id). Every field has a default member initializer, so a default-constructed State is fully initialized.
312
+ std::vector<Scalar> val_ = {}; // presumably the local buffer being reduced — TODO confirm against initialize()
313
+ size_t size_ = {}; // presumably the element count of val_ — TODO confirm
314
+
315
+ bool finished_adjustment_part_ = false ;
316
+ MsgSharedPtr<AllreduceRbnRawMsg<Scalar>> left_adjust_message_ = nullptr ; // null until set; NOTE(review): looks like a message held for the adjustment handlers — confirm
317
+ MsgSharedPtr<AllreduceRbnRawMsg<Scalar>> right_adjust_message_ = nullptr ; // null until set; see note on left_adjust_message_
318
+
319
+ int32_t mask_ = 1 ;
320
+ int32_t step_ = 0 ;
321
+ bool initialized_ = false ; // starts false; presumably flipped by initializeState()/initialize() — TODO confirm
322
+ bool completed_ = false ;
323
+
324
+ // Scatter: counters, per-step flags and buffered messages for the scatter-reduce phase
325
+ int32_t scatter_mask_ = 1 ;
326
+ int32_t scatter_step_ = 0 ;
327
+ int32_t scatter_num_recv_ = 0 ;
328
+ std::vector<bool > scatter_steps_recv_ = {}; // presumably one flag per scatter step — TODO confirm indexing
329
+ std::vector<bool > scatter_steps_reduced_ = {}; // presumably one flag per scatter step — TODO confirm indexing
330
+ std::vector<MsgSharedPtr<AllreduceRbnRawMsg<Scalar>>> scatter_messages_ =
331
+ {};
332
+ bool finished_scatter_part_ = false ;
333
+
334
+ // Gather: counters, per-step flags and buffered messages for the gather phase
335
+ int32_t gather_step_ = 0 ;
336
+ int32_t gather_mask_ = 1 ;
337
+ int32_t gather_num_recv_ = 0 ;
338
+ std::vector<bool > gather_steps_recv_ = {}; // presumably one flag per gather step — TODO confirm indexing
339
+ std::vector<bool > gather_steps_reduced_ = {}; // presumably one flag per gather step — TODO confirm indexing
340
+ std::vector<MsgSharedPtr<AllreduceRbnRawMsg<Scalar>>> gather_messages_ =
341
+ {};
342
+
343
+ std::vector<uint32_t > r_index_ = {}; // NOTE(review): r_/s_ presumably mean receive/send offsets and counts per step — confirm
344
+ std::vector<uint32_t > r_count_ = {};
345
+ std::vector<uint32_t > s_index_ = {};
346
+ std::vector<uint32_t > s_count_ = {};
347
+ };
348
+
349
+ size_t id_ = 0 ;
350
+ std::unordered_map<size_t , State> states_ = {};
338
351
NodeType num_nodes_ = {};
339
352
NodeType this_node_ = {};
340
353
@@ -343,33 +356,8 @@ struct Rabenseifner {
343
356
int32_t nprocs_pof2_ = {};
344
357
int32_t nprocs_rem_ = {};
345
358
346
- std::vector<uint32_t > r_index_ = {};
347
- std::vector<uint32_t > r_count_ = {};
348
- std::vector<uint32_t > s_index_ = {};
349
- std::vector<uint32_t > s_count_ = {};
350
-
351
359
NodeType vrt_node_ = {};
352
360
bool is_part_of_adjustment_group_ = false ;
353
- bool finished_adjustment_part_ = false ;
354
-
355
- bool completed_ = false ;
356
-
357
- // Scatter
358
- int32_t scatter_mask_ = 1 ;
359
- int32_t scatter_step_ = 0 ;
360
- int32_t scatter_num_recv_ = 0 ;
361
- std::vector<bool > scatter_steps_recv_ = {};
362
- std::vector<bool > scatter_steps_reduced_ = {};
363
- std::vector<MsgSharedPtr<AllreduceRbnRawMsg<Scalar>>> scatter_messages_ = {};
364
- bool finished_scatter_part_ = false ;
365
-
366
- // Gather
367
- int32_t gather_step_ = 0 ;
368
- int32_t gather_mask_ = 1 ;
369
- int32_t gather_num_recv_ = 0 ;
370
- std::vector<bool > gather_steps_recv_ = {};
371
- std::vector<bool > gather_steps_reduced_ = {};
372
- std::vector<MsgSharedPtr<AllreduceRbnRawMsg<Scalar>>> gather_messages_ = {};
373
361
};
374
362
375
363
} // namespace vt::collective::reduce::allreduce
0 commit comments