Skip to content

Commit 5c4fed6

Browse files
committed
#2240: Fix issues with handlers being executed and payload not being initialized
1 parent a537975 commit 5c4fed6

File tree

2 files changed

+90 -43 lines changed

2 files changed

+90 -43 lines changed

src/vt/collective/reduce/allreduce/rabenseifner.impl.h

+68 -34
Original file line number | Diff line number | Diff line change
@@ -74,7 +74,6 @@ Rabenseifner<DataT, Op, ObjT, finalHandler>::Rabenseifner(
7474
vrt_node_ = this_node_ - nprocs_rem_;
7575
}
7676

77-
vt_debug_print(terse, allreduce, "Rabenseifner constructor\n");
7877
initialize(generateNewId(), std::forward<Args>(data)...);
7978
}
8079

@@ -165,9 +164,9 @@ template <
165164
typename DataT, template <typename Arg> class Op, typename ObjT, auto finalHandler
166165
>
167166
void Rabenseifner<DataT, Op, ObjT, finalHandler>::executeFinalHan(size_t id) {
168-
// theCB()->makeSend<finalHandler>(parent_proxy_[this_node_]).sendTuple(std::make_tuple(val_));
169167
auto& state = states_.at(id);
170168
vt_debug_print(terse, allreduce, "Rabenseifner executing final handler ID = {}\n", id);
169+
171170
parent_proxy_[this_node_].template invoke<finalHandler>(state.val_);
172171
state.completed_ = true;
173172
}
@@ -176,7 +175,6 @@ template <
176175
typename DataT, template <typename Arg> class Op, typename ObjT,
177176
auto finalHandler>
178177
void Rabenseifner<DataT, Op, ObjT, finalHandler>::allreduce(size_t id) {
179-
vt_debug_print(terse, allreduce, "Rabenseifner allreduce is_part_of_adjustment_group_ = {}\n", is_part_of_adjustment_group_);
180178
if (is_part_of_adjustment_group_) {
181179
adjustForPowerOfTwo(id);
182180
} else {
@@ -193,7 +191,7 @@ void Rabenseifner<DataT, Op, ObjT, finalHandler>::adjustForPowerOfTwo(size_t id)
193191
auto const partner = is_even_ ? this_node_ + 1 : this_node_ - 1;
194192

195193
vt_debug_print(
196-
terse, allreduce, "Rabenseifner::adjustForPowerOfTwo: To Node {} ID = {}\n", partner, id
194+
terse, allreduce, "Rabenseifner AdjustInitial (To {}): ID = {}\n", partner, id
197195
);
198196

199197
if (is_even_) {
@@ -223,15 +221,23 @@ void Rabenseifner<DataT, Op, ObjT, finalHandler>::adjustForPowerOfTwoRightHalf(
223221

224222
auto& state = states_[msg->id_];
225223

226-
if(not state.initialized_){
227-
initializeState(msg->id_);
224+
if (state.val_.empty()) {
225+
if (not state.initialized_) {
226+
vt_debug_print(
227+
verbose, allreduce,
228+
"Rabenseifner AdjustRightHalf (From {}): State not initialized ID {}!\n",
229+
theContext()->getFromNodeCurrentTask(), msg->id_
230+
);
231+
232+
initializeState(msg->id_);
233+
}
228234
state.right_adjust_message_ = promoteMsg(msg);
229235

230236
return;
231237
}
232238

233239
vt_debug_print(
234-
terse, allreduce, "Rabenseifner::adjustForPowerOfTwoRightHalf: From Node {} ID = {}\n",
240+
terse, allreduce, "Rabenseifner AdjustRightHalf (From {}): ID = {}\n",
235241
theContext()->getFromNodeCurrentTask(), msg->id_
236242
);
237243

@@ -252,15 +258,22 @@ void Rabenseifner<DataT, Op, ObjT, finalHandler>::adjustForPowerOfTwoLeftHalf(
252258
AllreduceRbnRawMsg<Scalar>* msg) {
253259

254260
auto& state = states_[msg->id_];
255-
if(not state.initialized_){
256-
initializeState(msg->id_);
261+
if (state.val_.empty()) {
262+
if (not state.initialized_) {
263+
vt_debug_print(
264+
verbose, allreduce,
265+
"Rabenseifner AdjustLeftHalf (From {}): State not initialized ID {}!\n",
266+
theContext()->getFromNodeCurrentTask(), msg->id_);
267+
268+
initializeState(msg->id_);
269+
}
257270
state.left_adjust_message_ = promoteMsg(msg);
258271

259272
return;
260273
}
261274

262275
vt_debug_print(
263-
terse, allreduce, "Rabenseifner::adjustForPowerOfTwoLeftHalf: From Node {} ID = {}\n",
276+
terse, allreduce, "Rabenseifner AdjustLeftHalf (From {}): ID = {}\n",
264277
theContext()->getFromNodeCurrentTask(), msg->id_
265278
);
266279

@@ -276,7 +289,7 @@ void Rabenseifner<DataT, Op, ObjT, finalHandler>::adjustForPowerOfTwoFinalPart(
276289
AllreduceRbnRawMsg<Scalar>* msg) {
277290

278291
vt_debug_print(
279-
terse, allreduce, "Rabenseifner::adjustForPowerOfTwoFinalPart: From Node {} ID = {}\n",
292+
terse, allreduce, "Rabenseifner AdjustFinal (From {}): ID = {}\n",
280293
theContext()->getFromNodeCurrentTask(), msg->id_
281294
);
282295

@@ -295,7 +308,7 @@ template <
295308
typename DataT, template <typename Arg> class Op, typename ObjT, auto finalHandler
296309
>
297310
bool Rabenseifner<DataT, Op, ObjT, finalHandler>::scatterAllMessagesReceived(size_t id) {
298-
auto& state = states_.at(id);
311+
auto const& state = states_.at(id);
299312

300313
return std::all_of(
301314
state.scatter_steps_recv_.cbegin(), state.scatter_steps_recv_.cbegin() + state.scatter_step_,
@@ -306,15 +319,15 @@ template <
306319
typename DataT, template <typename Arg> class Op, typename ObjT, auto finalHandler
307320
>
308321
bool Rabenseifner<DataT, Op, ObjT, finalHandler>::scatterIsDone(size_t id) {
309-
auto& state = states_.at(id);
322+
auto const& state = states_.at(id);
310323
return (state.scatter_step_ == num_steps_) and (state.scatter_num_recv_ == num_steps_);
311324
}
312325

313326
template <
314327
typename DataT, template <typename Arg> class Op, typename ObjT, auto finalHandler
315328
>
316329
bool Rabenseifner<DataT, Op, ObjT, finalHandler>::scatterIsReady(size_t id) {
317-
auto& state = states_.at(id);
330+
auto const& state = states_.at(id);
318331
return ((is_part_of_adjustment_group_ and state.finished_adjustment_part_) and
319332
state.scatter_step_ == 0) or
320333
((state.scatter_mask_ < nprocs_pof2_) and scatterAllMessagesReceived(id));
@@ -326,12 +339,20 @@ template <
326339
void Rabenseifner<DataT, Op, ObjT, finalHandler>::scatterTryReduce(
327340
size_t id, int32_t step) {
328341
auto& state = states_.at(id);
329-
if (
330-
(step < state.scatter_step_) and not state.scatter_steps_reduced_[step] and
342+
343+
auto do_reduce = (step < state.scatter_step_) and
344+
not state.scatter_steps_reduced_[step] and
331345
state.scatter_steps_recv_[step] and
332-
std::all_of(
333-
state.scatter_steps_reduced_.cbegin(), state.scatter_steps_reduced_.cbegin() + step,
334-
[](auto const val) { return val; })) {
346+
std::all_of(state.scatter_steps_reduced_.cbegin(),
347+
state.scatter_steps_reduced_.cbegin() + step,
348+
[](auto const val) { return val; });
349+
350+
vt_debug_print(
351+
verbose, allreduce, "Rabenseifner ScatterTryReduce (Step = {} ID = {}): {}\n",
352+
step, id, do_reduce
353+
);
354+
355+
if (do_reduce) {
335356
auto& in_msg = state.scatter_messages_.at(step);
336357
auto& in_val = in_msg->val_;
337358
for (uint32_t i = 0; i < in_msg->size_; i++) {
@@ -356,15 +377,16 @@ void Rabenseifner<DataT, Op, ObjT, finalHandler>::scatterReduceIter(size_t id) {
356377

357378
vt_debug_print(
358379
terse, allreduce,
359-
"Rabenseifner Scatter (Send step {}): To Node {} starting with idx = {} and "
380+
"Rabenseifner Scatter (Send step {} to {}): Starting with idx = {} and "
360381
"count "
361382
"{} ID = {}\n",
362383
state.scatter_step_, dest, state.s_index_[state.scatter_step_],
363384
state.s_count_[state.scatter_step_], id
364385
);
365386

366387
proxy_[dest].template send<&Rabenseifner::scatterReduceIterHandler>(
367-
state.val_.data() + state.s_index_[state.scatter_step_], state.s_count_[state.scatter_step_], id, state.scatter_step_
388+
state.val_.data() + state.s_index_[state.scatter_step_],
389+
state.s_count_[state.scatter_step_], id, state.scatter_step_
368390
);
369391

370392
state.scatter_mask_ <<= 1;
@@ -387,15 +409,35 @@ void Rabenseifner<DataT, Op, ObjT, finalHandler>::scatterReduceIterHandler(
387409
AllreduceRbnRawMsg<Scalar>* msg) {
388410
auto& state = states_[msg->id_];
389411

390-
if(not state.initialized_){
391-
initializeState(msg->id_);
412+
if (state.val_.empty()) {
413+
if (not state.initialized_) {
414+
vt_debug_print(
415+
verbose, allreduce,
416+
"Rabenseifner Scatter (Recv step {} from {}): State not initialized "
417+
"for ID = "
418+
"{}!\n",
419+
msg->step_, theContext()->getFromNodeCurrentTask(), msg->id_);
420+
initializeState(msg->id_);
421+
}
422+
392423
state.scatter_messages_[msg->step_] = promoteMsg(msg);
393424
state.scatter_steps_recv_[msg->step_] = true;
394425
state.scatter_num_recv_++;
395426

396427
return;
397428
}
398429

430+
vt_debug_print(
431+
terse, allreduce,
432+
"Rabenseifner Scatter (Recv step {} from {}): initialized = {} "
433+
"scatter_mask_= {} nprocs_pof2_ = {}: scatterAllMessagesReceived() = {} "
434+
"state.finished_adjustment_part_ = {}"
435+
"idx = {} ID = {}\n",
436+
msg->step_, theContext()->getFromNodeCurrentTask(), state.initialized_,
437+
state.scatter_mask_, nprocs_pof2_, scatterAllMessagesReceived(msg->id_),
438+
state.finished_adjustment_part_, state.r_index_[msg->step_], msg->id_
439+
);
440+
399441
state.scatter_messages_[msg->step_] = promoteMsg(msg);
400442
state.scatter_steps_recv_[msg->step_] = true;
401443
state.scatter_num_recv_++;
@@ -406,14 +448,6 @@ void Rabenseifner<DataT, Op, ObjT, finalHandler>::scatterReduceIterHandler(
406448

407449
scatterTryReduce(msg->id_, msg->step_);
408450

409-
vt_debug_print(
410-
terse, allreduce,
411-
"Rabenseifner Scatter (Recv step {}): scatter_mask_= {} nprocs_pof2_ = {}: "
412-
"idx = {} from {} ID = {}\n",
413-
msg->step_, state.scatter_mask_, nprocs_pof2_, state.r_index_[msg->step_],
414-
theContext()->getFromNodeCurrentTask(), msg->id_
415-
);
416-
417451
if ((state.scatter_mask_ < nprocs_pof2_) and scatterAllMessagesReceived(msg->id_)) {
418452
scatterReduceIter(msg->id_);
419453
} else if (scatterIsDone(msg->id_)) {
@@ -516,9 +550,9 @@ void Rabenseifner<DataT, Op, ObjT, finalHandler>::gatherIterHandler(
516550
AllreduceRbnRawMsg<Scalar>* msg) {
517551
auto& state = states_.at(msg->id_);
518552
vt_debug_print(
519-
terse, allreduce, "Rabenseifner Gather (step {}): Received idx = {} from {} ID = {}\n",
520-
msg->step_, state.s_index_[msg->step_],
521-
theContext()->getFromNodeCurrentTask(), msg->id_
553+
terse, allreduce, "Rabenseifner Gather (Recv step {} from {}): idx = {} ID = {}\n",
554+
msg->step_, theContext()->getFromNodeCurrentTask(), state.s_index_[msg->step_],
555+
msg->id_
522556
);
523557

524558
state.gather_messages_[msg->step_] = promoteMsg(msg);

src/vt/collective/reduce/allreduce/recursive_doubling.impl.h

+22 -9
Original file line number | Diff line number | Diff line change
@@ -91,7 +91,8 @@ void RecursiveDoubling<DataT, Op, ObjT, finalHandler>::initialize(
9191
state.val_ = DataT{std::forward<Args>(data)...};
9292

9393
vt_debug_print(
94-
terse, allreduce, "RecursiveDoubling Initialize: size {} ID {}\n", DataType::size(state.val_), id
94+
terse, allreduce, "RecursiveDoubling Initialize: size {} ID {}\n",
95+
DataType::size(state.val_), id
9596
);
9697
}
9798

@@ -100,6 +101,9 @@ template <
100101
auto finalHandler>
101102
void RecursiveDoubling<DataT, Op, ObjT, finalHandler>::initializeState(size_t id){
102103
auto& state = states_[id];
104+
105+
vt_debug_print(terse, allreduce, "RecursiveDoubling initializing state for ID = {}\n", id);
106+
103107
state.messages_.resize(num_steps_, nullptr);
104108
state.steps_recv_.resize(num_steps_, false);
105109
state.steps_reduced_.resize(num_steps_, false);
@@ -129,8 +133,8 @@ void RecursiveDoubling<DataT, Op, ObjT, finalHandler>::adjustForPowerOfTwo(size_
129133
auto& state = states_.at(id);
130134
if (is_part_of_adjustment_group_ and not is_even_) {
131135
vt_debug_print(
132-
terse, allreduce, "RecursiveDoubling Part1: Sending to Node {} ID ={} \n", this_node_,
133-
this_node_ - 1, id
136+
terse, allreduce, "RecursiveDoubling AdjustInitial (To {}): ID = {} \n",
137+
this_node_, this_node_ - 1, id
134138
);
135139

136140
proxy_[this_node_ - 1]
@@ -148,8 +152,10 @@ void RecursiveDoubling<DataT, Op, ObjT, finalHandler>::
148152
adjustForPowerOfTwoHandler(AllreduceDblRawMsg<DataT>* msg) {
149153

150154
auto& state = states_[msg->id_];
151-
if(not state.initialized_) {
152-
initializeState(msg->id_);
155+
if (DataType::size(state.val_) == 0) {
156+
if (not state.initialized_) {
157+
initializeState(msg->id_);
158+
}
153159
state.adjust_message_ = promoteMsg(msg);
154160

155161
return;
@@ -240,8 +246,13 @@ void RecursiveDoubling<DataT, Op, ObjT, finalHandler>::tryReduce(size_t id, int3
240246
[](const auto val) { return val; });
241247

242248
vt_debug_print(
243-
terse, allreduce, "RecursiveDoubling Part2 (Reduce step {}): state.step_ = {} state.steps_reduced_[step] = {} state.steps_recv_[step] = {} all_msgs_received = {} ID = {} \n",
244-
step, state.step_, static_cast<bool>(state.steps_reduced_[step]), static_cast<bool>(state.steps_recv_[step]), all_msgs_received, id);
249+
terse, allreduce,
250+
"RecursiveDoubling Part2 (Reduce step {}): state.step_ = {} "
251+
"state.steps_reduced_[step] = {} state.steps_recv_[step] = {} "
252+
"all_msgs_received = {} ID = {} \n",
253+
step, state.step_, static_cast<bool>(state.steps_reduced_[step]),
254+
static_cast<bool>(state.steps_recv_[step]), all_msgs_received, id
255+
);
245256

246257
if (
247258
(step < state.step_) and not state.steps_reduced_[step] and
@@ -259,8 +270,10 @@ void RecursiveDoubling<DataT, Op, ObjT, finalHandler>::reduceIterHandler(
259270
AllreduceDblRawMsg<DataT>* msg) {
260271
auto& state = states_[msg->id_];
261272

262-
if(not state.initialized_){
263-
initializeState(msg->id_);
273+
if (DataType::size(state.val_) == 0) {
274+
if (not state.initialized_) {
275+
initializeState(msg->id_);
276+
}
264277
state.messages_.at(msg->step_) = promoteMsg(msg);
265278
state.steps_recv_[msg->step_] = true;
266279

0 commit comments

Comments
 (0)