@@ -74,7 +74,6 @@ Rabenseifner<DataT, Op, ObjT, finalHandler>::Rabenseifner(
74
74
vrt_node_ = this_node_ - nprocs_rem_;
75
75
}
76
76
77
- vt_debug_print (terse, allreduce, " Rabenseifner constructor\n " );
78
77
initialize (generateNewId (), std::forward<Args>(data)...);
79
78
}
80
79
// Rabenseifner::executeFinalHan(id): delivers the value held in the state
// registered under `id` to the compile-time `finalHandler`, then flags that
// state as completed.
//
// NOTE: this span is a rendered unified diff, not plain C++. Bare numeric
// lines are old/new line-number residue, a leading `-` marks a line removed
// by the change and a leading `+` marks a line added by it.
//
// @param id  key of the allreduce state in `states_`.
@@ -165,9 +164,9 @@ template <
165
164
typename DataT, template <typename Arg> class Op , typename ObjT, auto finalHandler
166
165
>
167
166
void Rabenseifner<DataT, Op, ObjT, finalHandler>::executeFinalHan(size_t id) {
168
- // theCB()->makeSend<finalHandler>(parent_proxy_[this_node_]).sendTuple(std::make_tuple(val_));
169
167
// Look up the existing state for this id via at().
// NOTE(review): assumes `states_` is a standard associative container, in
// which case at() throws for an unknown id instead of inserting - confirm.
auto & state = states_.at (id);
170
168
vt_debug_print (terse, allreduce, " Rabenseifner executing final handler ID = {}\n " , id);
169
+
171
170
// Call finalHandler through the parent proxy element indexed by this_node_,
// passing the state's val_.
parent_proxy_[this_node_].template invoke <finalHandler>(state.val_ );
172
171
// Record on the state that the final handler has been run for this id.
state.completed_ = true ;
173
172
}
@@ -176,7 +175,6 @@ template <
176
175
typename DataT, template <typename Arg> class Op , typename ObjT,
177
176
auto finalHandler>
178
177
void Rabenseifner<DataT, Op, ObjT, finalHandler>::allreduce(size_t id) {
179
- vt_debug_print (terse, allreduce, " Rabenseifner allreduce is_part_of_adjustment_group_ = {}\n " , is_part_of_adjustment_group_);
180
178
if (is_part_of_adjustment_group_) {
181
179
adjustForPowerOfTwo (id);
182
180
} else {
@@ -193,7 +191,7 @@ void Rabenseifner<DataT, Op, ObjT, finalHandler>::adjustForPowerOfTwo(size_t id)
193
191
auto const partner = is_even_ ? this_node_ + 1 : this_node_ - 1 ;
194
192
195
193
vt_debug_print (
196
- terse, allreduce, " Rabenseifner::adjustForPowerOfTwo: To Node {} ID = {}\n " , partner, id
194
+ terse, allreduce, " Rabenseifner AdjustInitial ( To {}): ID = {}\n " , partner, id
197
195
);
198
196
199
197
if (is_even_) {
@@ -223,15 +221,23 @@ void Rabenseifner<DataT, Op, ObjT, finalHandler>::adjustForPowerOfTwoRightHalf(
223
221
224
222
auto & state = states_[msg->id_ ];
225
223
226
- if (not state.initialized_ ){
227
- initializeState (msg->id_ );
224
+ if (state.val_ .empty ()) {
225
+ if (not state.initialized_ ) {
226
+ vt_debug_print (
227
+ verbose, allreduce,
228
+ " Rabenseifner AdjustRightHalf (From {}): State not initialized ID {}!\n " ,
229
+ theContext ()->getFromNodeCurrentTask (), msg->id_
230
+ );
231
+
232
+ initializeState (msg->id_ );
233
+ }
228
234
state.right_adjust_message_ = promoteMsg (msg);
229
235
230
236
return ;
231
237
}
232
238
233
239
vt_debug_print (
234
- terse, allreduce, " Rabenseifner::adjustForPowerOfTwoRightHalf: From Node {} ID = {}\n " ,
240
+ terse, allreduce, " Rabenseifner AdjustRightHalf ( From {}): ID = {}\n " ,
235
241
theContext ()->getFromNodeCurrentTask (), msg->id_
236
242
);
237
243
@@ -252,15 +258,22 @@ void Rabenseifner<DataT, Op, ObjT, finalHandler>::adjustForPowerOfTwoLeftHalf(
252
258
AllreduceRbnRawMsg<Scalar>* msg) {
253
259
254
260
auto & state = states_[msg->id_ ];
255
- if (not state.initialized_ ){
256
- initializeState (msg->id_ );
261
+ if (state.val_ .empty ()) {
262
+ if (not state.initialized_ ) {
263
+ vt_debug_print (
264
+ verbose, allreduce,
265
+ " Rabenseifner AdjustLeftHalf (From {}): State not initialized ID {}!\n " ,
266
+ theContext ()->getFromNodeCurrentTask (), msg->id_ );
267
+
268
+ initializeState (msg->id_ );
269
+ }
257
270
state.left_adjust_message_ = promoteMsg (msg);
258
271
259
272
return ;
260
273
}
261
274
262
275
vt_debug_print (
263
- terse, allreduce, " Rabenseifner::adjustForPowerOfTwoLeftHalf: From Node {} ID = {}\n " ,
276
+ terse, allreduce, " Rabenseifner AdjustLeftHalf ( From {}): ID = {}\n " ,
264
277
theContext ()->getFromNodeCurrentTask (), msg->id_
265
278
);
266
279
@@ -276,7 +289,7 @@ void Rabenseifner<DataT, Op, ObjT, finalHandler>::adjustForPowerOfTwoFinalPart(
276
289
AllreduceRbnRawMsg<Scalar>* msg) {
277
290
278
291
vt_debug_print (
279
- terse, allreduce, " Rabenseifner::adjustForPowerOfTwoFinalPart: From Node {} ID = {}\n " ,
292
+ terse, allreduce, " Rabenseifner AdjustFinal ( From {}): ID = {}\n " ,
280
293
theContext ()->getFromNodeCurrentTask (), msg->id_
281
294
);
282
295
@@ -295,7 +308,7 @@ template <
295
308
typename DataT, template <typename Arg> class Op , typename ObjT, auto finalHandler
296
309
>
297
310
bool Rabenseifner<DataT, Op, ObjT, finalHandler>::scatterAllMessagesReceived(size_t id) {
298
- auto & state = states_.at (id);
311
+ auto const & state = states_.at (id);
299
312
300
313
return std::all_of (
301
314
state.scatter_steps_recv_ .cbegin (), state.scatter_steps_recv_ .cbegin () + state.scatter_step_ ,
// Rabenseifner::scatterIsDone(id): reports whether the scatter phase of the
// state registered under `id` has finished.
//
// NOTE: this span is a rendered unified diff, not plain C++. Bare numeric
// lines are old/new line-number residue; the `-` line is the pre-change
// statement and the `+` line is its replacement (the state reference becomes
// const because this function only reads from it).
//
// @param id  key of the allreduce state in `states_`.
// @return true only when both `scatter_step_` and `scatter_num_recv_` of that
//         state equal `num_steps_`.
@@ -306,15 +319,15 @@ template <
306
319
typename DataT, template <typename Arg> class Op , typename ObjT, auto finalHandler
307
320
>
308
321
bool Rabenseifner<DataT, Op, ObjT, finalHandler>::scatterIsDone(size_t id) {
309
- auto & state = states_.at (id);
322
+ auto const & state = states_.at (id);
310
323
// Both counters must have reached num_steps_.
// NOTE(review): presumably scatter_step_ counts steps sent and
// scatter_num_recv_ counts messages received - confirm against the state type.
return (state.scatter_step_ == num_steps_) and (state.scatter_num_recv_ == num_steps_);
311
324
}
312
325
313
326
template <
314
327
typename DataT, template <typename Arg> class Op , typename ObjT, auto finalHandler
315
328
>
316
329
bool Rabenseifner<DataT, Op, ObjT, finalHandler>::scatterIsReady(size_t id) {
317
- auto & state = states_.at (id);
330
+ auto const & state = states_.at (id);
318
331
return ((is_part_of_adjustment_group_ and state.finished_adjustment_part_ ) and
319
332
state.scatter_step_ == 0 ) or
320
333
((state.scatter_mask_ < nprocs_pof2_) and scatterAllMessagesReceived (id));
@@ -326,12 +339,20 @@ template <
326
339
void Rabenseifner<DataT, Op, ObjT, finalHandler>::scatterTryReduce(
327
340
size_t id, int32_t step) {
328
341
auto & state = states_.at (id);
329
- if (
330
- (step < state.scatter_step_ ) and not state.scatter_steps_reduced_ [step] and
342
+
343
+ auto do_reduce = (step < state.scatter_step_ ) and
344
+ not state.scatter_steps_reduced_ [step] and
331
345
state.scatter_steps_recv_ [step] and
332
- std::all_of (
333
- state.scatter_steps_reduced_ .cbegin (), state.scatter_steps_reduced_ .cbegin () + step,
334
- [](auto const val) { return val; })) {
346
+ std::all_of (state.scatter_steps_reduced_ .cbegin (),
347
+ state.scatter_steps_reduced_ .cbegin () + step,
348
+ [](auto const val) { return val; });
349
+
350
+ vt_debug_print (
351
+ verbose, allreduce, " Rabenseifner ScatterTryReduce (Step = {} ID = {}): {}\n " ,
352
+ step, id, do_reduce
353
+ );
354
+
355
+ if (do_reduce) {
335
356
auto & in_msg = state.scatter_messages_ .at (step);
336
357
auto & in_val = in_msg->val_ ;
337
358
for (uint32_t i = 0 ; i < in_msg->size_ ; i++) {
@@ -356,15 +377,16 @@ void Rabenseifner<DataT, Op, ObjT, finalHandler>::scatterReduceIter(size_t id) {
356
377
357
378
vt_debug_print (
358
379
terse, allreduce,
359
- " Rabenseifner Scatter (Send step {}): To Node {} starting with idx = {} and "
380
+ " Rabenseifner Scatter (Send step {} to {}): Starting with idx = {} and "
360
381
" count "
361
382
" {} ID = {}\n " ,
362
383
state.scatter_step_ , dest, state.s_index_ [state.scatter_step_ ],
363
384
state.s_count_ [state.scatter_step_ ], id
364
385
);
365
386
366
387
proxy_[dest].template send <&Rabenseifner::scatterReduceIterHandler>(
367
- state.val_ .data () + state.s_index_ [state.scatter_step_ ], state.s_count_ [state.scatter_step_ ], id, state.scatter_step_
388
+ state.val_ .data () + state.s_index_ [state.scatter_step_ ],
389
+ state.s_count_ [state.scatter_step_ ], id, state.scatter_step_
368
390
);
369
391
370
392
state.scatter_mask_ <<= 1 ;
@@ -387,15 +409,35 @@ void Rabenseifner<DataT, Op, ObjT, finalHandler>::scatterReduceIterHandler(
387
409
AllreduceRbnRawMsg<Scalar>* msg) {
388
410
auto & state = states_[msg->id_ ];
389
411
390
- if (not state.initialized_ ){
391
- initializeState (msg->id_ );
412
+ if (state.val_ .empty ()) {
413
+ if (not state.initialized_ ) {
414
+ vt_debug_print (
415
+ verbose, allreduce,
416
+ " Rabenseifner Scatter (Recv step {} from {}): State not initialized "
417
+ " for ID = "
418
+ " {}!\n " ,
419
+ msg->step_ , theContext ()->getFromNodeCurrentTask (), msg->id_ );
420
+ initializeState (msg->id_ );
421
+ }
422
+
392
423
state.scatter_messages_ [msg->step_ ] = promoteMsg (msg);
393
424
state.scatter_steps_recv_ [msg->step_ ] = true ;
394
425
state.scatter_num_recv_ ++;
395
426
396
427
return ;
397
428
}
398
429
430
+ vt_debug_print (
431
+ terse, allreduce,
432
+ " Rabenseifner Scatter (Recv step {} from {}): initialized = {} "
433
+ " scatter_mask_= {} nprocs_pof2_ = {}: scatterAllMessagesReceived() = {} "
434
+ " state.finished_adjustment_part_ = {}"
435
+ " idx = {} ID = {}\n " ,
436
+ msg->step_ , theContext ()->getFromNodeCurrentTask (), state.initialized_ ,
437
+ state.scatter_mask_ , nprocs_pof2_, scatterAllMessagesReceived (msg->id_ ),
438
+ state.finished_adjustment_part_ , state.r_index_ [msg->step_ ], msg->id_
439
+ );
440
+
399
441
state.scatter_messages_ [msg->step_ ] = promoteMsg (msg);
400
442
state.scatter_steps_recv_ [msg->step_ ] = true ;
401
443
state.scatter_num_recv_ ++;
@@ -406,14 +448,6 @@ void Rabenseifner<DataT, Op, ObjT, finalHandler>::scatterReduceIterHandler(
406
448
407
449
scatterTryReduce (msg->id_ , msg->step_ );
408
450
409
- vt_debug_print (
410
- terse, allreduce,
411
- " Rabenseifner Scatter (Recv step {}): scatter_mask_= {} nprocs_pof2_ = {}: "
412
- " idx = {} from {} ID = {}\n " ,
413
- msg->step_ , state.scatter_mask_ , nprocs_pof2_, state.r_index_ [msg->step_ ],
414
- theContext ()->getFromNodeCurrentTask (), msg->id_
415
- );
416
-
417
451
if ((state.scatter_mask_ < nprocs_pof2_) and scatterAllMessagesReceived (msg->id_ )) {
418
452
scatterReduceIter (msg->id_ );
419
453
} else if (scatterIsDone (msg->id_ )) {
@@ -516,9 +550,9 @@ void Rabenseifner<DataT, Op, ObjT, finalHandler>::gatherIterHandler(
516
550
AllreduceRbnRawMsg<Scalar>* msg) {
517
551
auto & state = states_.at (msg->id_ );
518
552
vt_debug_print (
519
- terse, allreduce, " Rabenseifner Gather (step {}): Received idx = {} from {} ID = {}\n " ,
520
- msg->step_ , state.s_index_ [msg->step_ ],
521
- theContext ()-> getFromNodeCurrentTask (), msg->id_
553
+ terse, allreduce, " Rabenseifner Gather (Recv step {} from {} ): idx = {} ID = {}\n " ,
554
+ msg->step_ , theContext ()-> getFromNodeCurrentTask (), state.s_index_ [msg->step_ ],
555
+ msg->id_
522
556
);
523
557
524
558
state.gather_messages_ [msg->step_ ] = promoteMsg (msg);
0 commit comments