Skip to content

Commit f01bad5

Browse files
committed
#2240: Add unit tests for new allreduce and cleanup code
1 parent dd33fd0 commit f01bad5

File tree

7 files changed

+178
-61
lines changed

7 files changed

+178
-61
lines changed

src/vt/collective/reduce/allreduce/rabenseifner.h

+26-11
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,16 @@
4141
//@HEADER
4242
*/
4343

44+
4445
#if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_H
4546
#define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_H
4647

4748
#include "vt/config.h"
4849
#include "vt/context/context.h"
4950
#include "vt/messaging/message/message.h"
5051
#include "vt/objgroup/proxy/proxy_objgroup.h"
52+
#include "vt/registry/auto/auto_registry.h"
53+
#include "vt/pipe/pipe_manager.h"
5154

5255
#include <tuple>
5356
#include <cstdint>
@@ -95,7 +98,6 @@ struct Rabenseifner {
9598
vt::objgroup::proxy::Proxy<ObjT> parentProxy, NodeType num_nodes,
9699
Args&&... args)
97100
: parent_proxy_(parentProxy),
98-
val_(std::forward<Args>(args)...),
99101
num_nodes_(num_nodes),
100102
this_node_(vt::theContext()->getNode()),
101103
is_even_(this_node_ % 2 == 0),
@@ -104,7 +106,15 @@ struct Rabenseifner {
104106
nprocs_rem_(num_nodes_ - nprocs_pof2_),
105107
gather_step_(num_steps_ - 1),
106108
gather_mask_(nprocs_pof2_ >> 1),
107-
finished_adjustment_part_(nprocs_rem_ == 0) {
109+
finished_adjustment_part_(nprocs_rem_ == 0)
110+
{
111+
initialize(std::forward<Args>(args)...);
112+
}
113+
114+
template <typename... Args>
115+
void initialize(Args&&... args) {
116+
val_ = DataT(std::forward<Args>(args)...);
117+
108118
is_part_of_adjustment_group_ = this_node_ < (2 * nprocs_rem_);
109119
if (is_part_of_adjustment_group_) {
110120
if (is_even_) {
@@ -156,6 +166,13 @@ struct Rabenseifner {
156166
scatter_steps_recv_.resize(num_steps_, false);
157167
}
158168

169+
void executeFinalHan() {
170+
171+
// theCB()->makeSend<finalHandler>(parent_proxy_[this_node_]).sendTuple(std::make_tuple(val_));
172+
parent_proxy_[this_node_].template invoke<finalHandler>(val_);
173+
completed_ = true;
174+
}
175+
159176
void allreduce() {
160177
if (nprocs_rem_) {
161178
adjustForPowerOfTwo();
@@ -181,7 +198,7 @@ struct Rabenseifner {
181198
}
182199

183200
void adjustForPowerOfTwoRightHalf(AllreduceRbnMsg<DataT>* msg) {
184-
for (int i = 0; i < msg->val_.size(); i++) {
201+
for (uint32_t i = 0; i < msg->val_.size(); i++) {
185202
val_[(val_.size() / 2) + i] += msg->val_[i];
186203
}
187204

@@ -192,13 +209,13 @@ struct Rabenseifner {
192209
}
193210

194211
void adjustForPowerOfTwoLeftHalf(AllreduceRbnMsg<DataT>* msg) {
195-
for (int i = 0; i < msg->val_.size(); i++) {
212+
for (uint32_t i = 0; i < msg->val_.size(); i++) {
196213
val_[i] += msg->val_[i];
197214
}
198215
}
199216

200217
void adjustForPowerOfTwoFinalPart(AllreduceRbnMsg<DataT>* msg) {
201-
for (int i = 0; i < msg->val_.size(); i++) {
218+
for (uint32_t i = 0; i < msg->val_.size(); i++) {
202219
val_[(val_.size() / 2) + i] = msg->val_[i];
203220
}
204221

@@ -243,7 +260,7 @@ struct Rabenseifner {
243260
[](const auto val) { return val; })) {
244261
auto& in_msg = scatter_messages_.at(step);
245262
auto& in_val = in_msg->val_;
246-
for (int i = 0; i < in_val.size(); i++) {
263+
for (uint32_t i = 0; i < in_val.size(); i++) {
247264
Op<typename DataT::value_type>()(
248265
val_[r_index_[in_msg->step_] + i], in_val[i]);
249266
}
@@ -339,7 +356,7 @@ struct Rabenseifner {
339356
if (doRed) {
340357
auto& in_msg = gather_messages_.at(step);
341358
auto& in_val = in_msg->val_;
342-
for (int i = 0; i < in_val.size(); i++) {
359+
for (uint32_t i = 0; i < in_val.size(); i++) {
343360
val_[s_index_[in_msg->step_] + i] = in_val[i];
344361
}
345362

@@ -417,8 +434,7 @@ struct Rabenseifner {
417434
sendToExcludedNodes();
418435
}
419436

420-
parent_proxy_[this_node_].template invoke<finalHandler>(val_);
421-
completed_ = true;
437+
executeFinalHan();
422438
}
423439

424440
void sendToExcludedNodes() {
@@ -435,8 +451,7 @@ struct Rabenseifner {
435451
void sendToExcludedNodesHandler(AllreduceRbnMsg<DataT>* msg) {
436452
val_ = msg->val_;
437453

438-
parent_proxy_[this_node_].template invoke<finalHandler>(val_);
439-
completed_ = true;
454+
executeFinalHan();
440455
}
441456

442457
vt::objgroup::proxy::Proxy<Rabenseifner> proxy_ = {};

src/vt/collective/reduce/allreduce/recursive_doubling.h

+10-4
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,19 @@ struct DistanceDoubling {
9797
vt::objgroup::proxy::Proxy<ObjT> parentProxy, NodeType num_nodes,
9898
Args&&... args)
9999
: parent_proxy_(parentProxy),
100-
val_(std::forward<Args>(args)...),
101100
num_nodes_(num_nodes),
102101
this_node_(vt::theContext()->getNode()),
103102
is_even_(this_node_ % 2 == 0),
104103
num_steps_(static_cast<int32_t>(log2(num_nodes_))),
105104
nprocs_pof2_(1 << num_steps_),
106105
nprocs_rem_(num_nodes_ - nprocs_pof2_),
107106
finished_adjustment_part_(nprocs_rem_ == 0) {
107+
initialize(std::forward<Args>(args)...);
108+
}
109+
110+
template <typename... Args>
111+
void initialize(Args&&... args) {
112+
val_ = DataT(std::forward<Args>(args)...);
108113
is_part_of_adjustment_group_ = this_node_ < (2 * nprocs_rem_);
109114
if (is_part_of_adjustment_group_) {
110115
if (is_even_) {
@@ -168,8 +173,8 @@ struct DistanceDoubling {
168173
[](const auto val) { return val; });
169174
}
170175
bool isReady() {
171-
return (is_part_of_adjustment_group_ and finished_adjustment_part_) and
172-
step_ == 0 or
176+
return ((is_part_of_adjustment_group_ and finished_adjustment_part_) and
177+
step_ == 0) or
173178
allMessagesReceived();
174179
}
175180

@@ -279,8 +284,9 @@ struct DistanceDoubling {
279284
vt::objgroup::proxy::Proxy<ObjT> parent_proxy_ = {};
280285

281286
DataT val_ = {};
282-
NodeType this_node_ = {};
283287
NodeType num_nodes_ = {};
288+
NodeType this_node_ = {};
289+
284290
bool is_even_ = false;
285291
int32_t num_steps_ = {};
286292
int32_t nprocs_pof2_ = {};

src/vt/objgroup/manager.h

+6
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
//@HEADER
4242
*/
4343

44+
#include "vt/configs/types/types_type.h"
4445
#if !defined INCLUDED_VT_OBJGROUP_MANAGER_H
4546
#define INCLUDED_VT_OBJGROUP_MANAGER_H
4647

@@ -291,6 +292,9 @@ struct ObjGroupManager : runtime::component::Component<ObjGroupManager> {
291292
ProxyType<ObjT> proxy, std::string const& name, std::string const& parent = ""
292293
);
293294

295+
template <typename Reducer, auto f, typename ObjT, template <typename Arg> class Op, typename DataT>
296+
ObjGroupManager::PendingSendType allreduce(ProxyType<ObjT> proxy, const DataT& data);
297+
294298
template <auto f, typename ObjT, template <typename Arg> class Op, typename DataT>
295299
ObjGroupManager::PendingSendType allreduce(ProxyType<ObjT> proxy, const DataT& data);
296300

@@ -504,6 +508,8 @@ ObjGroupManager::PendingSendType allreduce(ProxyType<ObjT> proxy, const DataT& d
504508
std::unordered_map<ObjGroupProxyType, std::vector<ActionType>> pending_;
505509
/// Map of object groups' labels
506510
std::unordered_map<ObjGroupProxyType, std::string> labels_;
511+
512+
std::unordered_map<ObjGroupProxyType, ObjGroupProxyType> reducers_;
507513
};
508514

509515
}} /* end namespace vt::objgroup */

src/vt/objgroup/manager.impl.h

+61-44
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
//@HEADER
4242
*/
4343

44+
#include "vt/configs/types/types_sentinels.h"
4445
#if !defined INCLUDED_VT_OBJGROUP_MANAGER_IMPL_H
4546
#define INCLUDED_VT_OBJGROUP_MANAGER_IMPL_H
4647

@@ -58,7 +59,10 @@
5859
#include "vt/messaging/active.h"
5960
#include "vt/elm/elm_id_bits.h"
6061
#include "vt/messaging/message/smart_ptr.h"
62+
#include "vt/collective/reduce/allreduce/rabenseifner.h"
63+
#include "vt/collective/reduce/allreduce/recursive_doubling.h"
6164
#include <utility>
65+
#include <array>
6266

6367
#include <memory>
6468

@@ -264,57 +268,70 @@ ObjGroupManager::PendingSendType ObjGroupManager::broadcast(MsgSharedPtr<MsgT> m
264268
return objgroup::broadcast(msg,han);
265269
}
266270

271+
272+
// Helper trait to detect if a type is a specialization of a given variadic template
273+
template <template <typename...> class Template, typename T>
274+
struct is_specialization_of : std::false_type {};
275+
276+
template <template <typename...> class Template, typename... Args>
277+
struct is_specialization_of<Template, Template<Args...>> : std::true_type {};
278+
279+
// Specialized trait for std::array
280+
template <typename T>
281+
struct is_std_array : std::false_type {};
282+
283+
template <typename T, std::size_t N>
284+
struct is_std_array<std::array<T, N>> : std::true_type {};
285+
286+
// Trait to detect if a type is a standard container (std::vector or std::array in this case)
287+
template <typename T>
288+
struct is_std_container : std::integral_constant<bool,
289+
is_specialization_of<std::vector, T>::value || is_std_array<T>::value> {};
290+
291+
template <
292+
typename Reducer, auto f, typename ObjT, template <typename Arg> class Op, typename DataT>
293+
ObjGroupManager::PendingSendType ObjGroupManager::allreduce(
294+
ProxyType<ObjT> proxy, const DataT& data) {
295+
return PendingSendType{
296+
theTerm()->getEpoch(), [=] {
297+
auto const this_node = vt::theContext()->getNode();
298+
auto const num_nodes = theContext()->getNumNodes();
299+
300+
proxy::Proxy<Reducer> grp_proxy = {};
301+
302+
if (reducers_.find(proxy.getProxy()) != reducers_.end()) {
303+
auto* obj = reinterpret_cast<Reducer*>(
304+
objs_[reducers_[proxy.getProxy()]]->getPtr()
305+
);
306+
obj->initialize(data);
307+
grp_proxy = obj->proxy_;
308+
} else {
309+
grp_proxy = vt::theObjGroup()->makeCollective<Reducer>(
310+
"allreduce_rabenseifner", proxy, num_nodes, data);
311+
grp_proxy[this_node].get()->proxy_ = grp_proxy;
312+
}
313+
314+
grp_proxy[this_node].template invoke<&Reducer::allreduce>();
315+
}};
316+
}
317+
267318
template <
268319
auto f, typename ObjT, template <typename Arg> class Op, typename DataT>
269320
ObjGroupManager::PendingSendType
270321
ObjGroupManager::allreduce(ProxyType<ObjT> proxy, const DataT& data) {
271-
// check payload size and choose appropriate algorithm
272-
273-
auto const this_node = vt::theContext()->getNode();
274-
auto const num_nodes = theContext()->getNumNodes();
275-
276-
if (num_nodes < 2) {
322+
if (theContext()->getNumNodes() < 2) {
277323
return PendingSendType{nullptr};
278324
}
279325

280-
// using Reducer = collective::reduce::allreduce::Rabenseifner<DataT>;
281-
// using Reducer = collective::reduce::allreduce::DistanceDoubling<DataT, Op, ObjT, f>;
282-
283-
return PendingSendType{theTerm()->getEpoch(), [=] {
284-
// auto grp_proxy =
285-
// vt::theObjGroup()->makeCollective<Reducer>("allreduce_rabenseifner");
286-
// if constexpr (std::is_same_v<
287-
// Reducer,
288-
// collective::reduce::allreduce::DistanceDoubling<DataT, Op, ObjT, f>>) {
289-
// grp_proxy[this_node].template invoke<&Reducer::initialize>(
290-
// data, grp_proxy, proxy, num_nodes);
291-
292-
// grp_proxy[this_node].template invoke<&Reducer::partOne>();
293-
294-
// } else if constexpr (std::is_same_v<
295-
// Reducer,
296-
// collective::reduce::allreduce::Rabenseifner<
297-
// DataT, Op, ObjT, f>>) {
298-
// grp_proxy[this_node].template invoke<&Reducer::initialize>(
299-
// data, grp_proxy, num_nodes);
300-
301-
// if (grp_proxy.get()->nprocs_rem_) {
302-
// vt::runInEpochCollective(
303-
// [=] { grp_proxy[this_node].template invoke<&Reducer::partOne>(); });
304-
// }
305-
306-
// vt::runInEpochCollective(
307-
// [=] { grp_proxy[this_node].template invoke<&Reducer::partTwo>(); });
308-
309-
// vt::runInEpochCollective(
310-
// [=] { grp_proxy[this_node].template invoke<&Reducer::partThree>(); });
311-
312-
// if (grp_proxy.get()->nprocs_rem_) {
313-
// vt::runInEpochCollective(
314-
// [=] { grp_proxy[this_node].template invoke<&Reducer::partFour>(); });
315-
// }
316-
// }
317-
}};
326+
if constexpr (is_std_container<DataT>::value) {
327+
using Reducer =
328+
vt::collective::reduce::allreduce::Rabenseifner<DataT, Op, ObjT, f>;
329+
return allreduce<Reducer, f, ObjT, Op>(proxy, data);
330+
} else {
331+
using Reducer =
332+
vt::collective::reduce::allreduce::DistanceDoubling<DataT, Op, ObjT, f>;
333+
return allreduce<Reducer, f, ObjT, Op>(proxy, data);
334+
}
318335
}
319336

320337
template <typename ObjT, typename MsgT, ActiveTypedFnType<MsgT> *f>

tests/perf/allreduce.cc

+4-2
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,10 @@ VT_PERF_TEST(MyTest, test_allreduce_recursive_doubling) {
184184
auto grp_proxy = vt::theObjGroup()->makeCollective<Reducer>(
185185
"allreduce_recursive_doubling", proxy, num_nodes_, data);
186186
grp_proxy[my_node_].get()->proxy_ = grp_proxy;
187-
vt::runInEpochCollective(
188-
[=] { grp_proxy[my_node_].template invoke<&Reducer::allreduce>(); });
187+
188+
theCollective()->barrier();
189+
StartTimer(proxy[theContext()->getNode()].get()->timer_name_);
190+
grp_proxy[my_node_].template invoke<&Reducer::allreduce>();
189191
}
190192

191193
VT_PERF_TEST_MAIN()

tests/unit/objgroup/test_objgroup.cc

+43
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,49 @@ TEST_F(TestObjGroup, test_proxy_reduce) {
256256
}
257257
}
258258

259+
TEST_F(TestObjGroup, test_proxy_allreduce) {
260+
using namespace vt::collective;
261+
262+
auto const my_node = vt::theContext()->getNode();
263+
264+
TestObjGroup::total_verify_expected_ = 0;
265+
auto proxy = vt::theObjGroup()->makeCollective<MyObjA>("test_proxy_reduce");
266+
267+
vt::theCollective()->barrier();
268+
269+
runInEpochCollective(
270+
[&] { proxy.allreduce_h<&MyObjA::verifyAllred<1>, PlusOp>(my_node); }
271+
);
272+
273+
EXPECT_EQ(MyObjA::total_verify_expected_, 1);
274+
275+
runInEpochCollective(
276+
[&] { proxy.allreduce_h<&MyObjA::verifyAllred<2>, PlusOp>(4); }
277+
);
278+
279+
EXPECT_EQ(MyObjA::total_verify_expected_, 2);
280+
281+
runInEpochCollective(
282+
[&] { proxy.allreduce_h<&MyObjA::verifyAllred<3>, MaxOp>(my_node); }
283+
);
284+
285+
EXPECT_EQ(MyObjA::total_verify_expected_, 3);
286+
287+
runInEpochCollective([&] {
288+
using Reducer =
289+
vt::collective::reduce::allreduce::Rabenseifner<std::vector<int>, PlusOp, MyObjA, &MyObjA::verifyAllredVec>;
290+
std::vector<int> payload(256, my_node);
291+
theObjGroup()->allreduce<Reducer, &MyObjA::verifyAllredVec, MyObjA, PlusOp>(
292+
proxy, payload
293+
);
294+
theObjGroup()->allreduce<Reducer, &MyObjA::verifyAllredVec, MyObjA, PlusOp>(
295+
proxy, payload
296+
);
297+
});
298+
299+
EXPECT_EQ(MyObjA::total_verify_expected_, 5);
300+
}
301+
259302
TEST_F(TestObjGroup, test_proxy_invoke) {
260303
auto const& this_node = theContext()->getNode();
261304

0 commit comments

Comments
 (0)