Skip to content

Commit 57b8cab

Browse files
committed
#2240: Store Reducers by tuple(ProxyType, DataType, OperandType)
1 parent c5232dc commit 57b8cab

File tree

10 files changed

+105
-40
lines changed

10 files changed

+105
-40
lines changed

src/vt/collective/reduce/allreduce/helpers.h

+8
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,18 @@
4343

4444
#if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_HELPERS_H
4545
#define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_HELPERS_H
46+
4647
#include "data_handler.h"
4748
#include "rabenseifner_msg.h"
4849
#include "vt/messaging/message/shared_message.h"
50+
4951
#include <vector>
52+
#include <type_traits>
53+
54+
namespace vt {

/// Yields \c T with any reference qualifier and top-level \c const /
/// \c volatile stripped (same result as C++20's \c std::remove_cvref_t).
template <typename T>
using remove_cvref =
  typename std::remove_cv<typename std::remove_reference<T>::type>::type;

} // namespace vt
5058

5159
namespace vt::collective::reduce::allreduce {
5260

src/vt/collective/reduce/allreduce/rabenseifner.h

+2
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,10 @@ template <
9393
typename DataT, template <typename Arg> class Op, typename ObjT, auto finalHandler
9494
>
9595
struct Rabenseifner {
96+
using Data = DataT;
9697
using DataType = DataHandler<DataT>;
9798
using Scalar = typename DataType::Scalar;
99+
using ReduceOp = Op<Scalar>;
98100
using DataHelperT = DataHelper<Scalar, DataT>;
99101
using StateT = State<Scalar, DataT>;
100102

src/vt/collective/reduce/allreduce/recursive_doubling.h

+2
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,10 @@ template <
116116
typename DataT, template <typename Arg> class Op, typename ObjT,
117117
auto finalHandler>
118118
struct RecursiveDoubling {
119+
using Data = DataT;
119120
using DataType = DataHandler<DataT>;
120121
using Scalar = typename DataHandler<DataT>::Scalar;
122+
using ReduceOp = Op<Scalar>;
121123
/**
122124
* \brief Constructor for RecursiveDoubling class.
123125
*

src/vt/collective/reduce/allreduce/recursive_doubling.impl.h

-9
Original file line numberDiff line numberDiff line change
@@ -354,15 +354,6 @@ void RecursiveDoubling<DataT, Op, ObjT, finalHandler>::finalPart(size_t id) {
354354
parent_proxy_[this_node_].template invoke<finalHandler>(state.val_);
355355

356356
state.completed_ = true;
357-
358-
state.adjust_message_ = nullptr;
359-
state.messages_.clear();
360-
361-
states_.erase(id);
362-
// std::fill(state.messages_.begin(), state.messages_.end(), nullptr);
363-
364-
// state.steps_recv_.assign(num_steps_, false);
365-
// state.steps_reduced_.assign(num_steps_, false);
366357
}
367358

368359
} // namespace vt::collective::reduce::allreduce

src/vt/objgroup/manager.h

+11-3
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,13 @@
5858
#include "vt/messaging/pending_send.h"
5959
#include "vt/elm/elm_id.h"
6060
#include "vt/utils/fntraits/fntraits.h"
61+
#include "vt/utils/hash/hash_tuple.h"
6162

6263
#include <memory>
6364
#include <functional>
6465
#include <unordered_map>
6566
#include <vector>
67+
#include <typeindex>
6668

6769
namespace vt { namespace objgroup {
6870

@@ -91,6 +93,11 @@ struct ObjGroupManager : runtime::component::Component<ObjGroupManager> {
9193
using HolderBaseType = holder::HolderBase;
9294
using HolderBasePtrType = std::unique_ptr<HolderBaseType>;
9395
using PendingSendType = messaging::PendingSend;
96+
using ReduceDataType = std::type_index;
97+
using ReduceOperandType = std::type_index;
98+
using ReducerMapType = std::unordered_map<
99+
std::tuple<ObjGroupProxyType, ReduceDataType, ReduceOperandType>,
100+
ObjGroupProxyType>;
94101

95102
public:
96103
/**
@@ -507,9 +514,10 @@ ObjGroupManager::PendingSendType allreduce(ProxyType<ObjT> proxy, Args&&... data
507514
std::unordered_map<ObjGroupProxyType, std::vector<ActionType>> pending_;
508515
/// Map of object groups' labels
509516
std::unordered_map<ObjGroupProxyType, std::string> labels_;
510-
511-
std::unordered_map<ObjGroupProxyType, ObjGroupProxyType> reducersRD_;
512-
std::unordered_map<ObjGroupProxyType, ObjGroupProxyType> reducersR_;
517+
/// Recursive Doubling reducers
518+
ReducerMapType reducers_recursive_doubling_;
519+
/// Rabenseifner reducers
520+
ReducerMapType reducers_rabenseifner_;
513521
};
514522

515523
}} /* end namespace vt::objgroup */

src/vt/objgroup/manager.impl.h

+26-8
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
#include "vt/collective/reduce/allreduce/rabenseifner.h"
6262
#include "vt/collective/reduce/allreduce/recursive_doubling.h"
6363
#include "vt/collective/reduce/allreduce/type.h"
64+
#include "vt/collective/reduce/allreduce/helpers.h"
6465
#include <utility>
6566
#include <array>
6667

@@ -279,21 +280,36 @@ ObjGroupManager::PendingSendType ObjGroupManager::allreduce(
279280

280281
proxy::Proxy<Reducer> grp_proxy = {};
281282

282-
auto& reducers = Reducer::type_ == ReducerType::Rabenseifner ? reducersR_ : reducersRD_;
283-
if (reducers.find(proxy.getProxy()) != reducers.end()) {
284-
auto* obj = reinterpret_cast<Reducer*>(
285-
objs_.at(reducers.at(proxy.getProxy()))->getPtr()
283+
auto& reducers = Reducer::type_ == ReducerType::Rabenseifner ?
284+
reducers_rabenseifner_ :
285+
reducers_recursive_doubling_;
286+
auto const key = std::make_tuple(
287+
proxy.getProxy(), std::type_index(typeid(typename Reducer::Data)),
288+
std::type_index(typeid(typename Reducer::ReduceOp))
289+
);
290+
if (reducers.find(key) != reducers.end()) {
291+
vt_debug_print(
292+
verbose, allreduce, "Found reducer (type: {}) for proxy {:x}",
293+
TypeToString(Reducer::type_), proxy.getProxy()
286294
);
295+
296+
auto* obj =
297+
reinterpret_cast<Reducer*>(objs_.at(reducers.at(key))->getPtr());
287298
id = obj->generateNewId();
288299
obj->initialize(id, std::forward<Args>(data)...);
289300
grp_proxy = obj->proxy_;
290301
} else {
302+
vt_debug_print(
303+
verbose, allreduce, "Creating reducer (type: {}) for proxy {:x}",
304+
TypeToString(Reducer::type_), proxy.getProxy()
305+
);
306+
291307
grp_proxy = vt::theObjGroup()->makeCollective<Reducer>(
292-
TypeToString(Reducer::type_), proxy,
293-
num_nodes, std::forward<Args>(data)...
308+
TypeToString(Reducer::type_), proxy, num_nodes,
309+
std::forward<Args>(data)...
294310
);
295311
grp_proxy[this_node].get()->proxy_ = grp_proxy;
296-
reducers[proxy.getProxy()] = grp_proxy.getProxy();
312+
reducers[key] = grp_proxy.getProxy();
297313
id = grp_proxy[this_node].get()->id_ - 1;
298314
}
299315

@@ -314,9 +330,10 @@ ObjGroupManager::allreduce(ProxyType<ObjT> proxy, Args&&... data) {
314330
}
315331

316332
auto const payload_size =
317-
collective::reduce::allreduce::DataHandler<DataT>::size(
333+
collective::reduce::allreduce::DataHandler<remove_cvref<DataT>>::size(
318334
std::forward<Args>(data)...
319335
);
336+
320337
if (payload_size < 2048) {
321338
using Reducer =
322339
vt::collective::reduce::allreduce::RecursiveDoubling<DataT, Op, ObjT, f>;
@@ -327,6 +344,7 @@ ObjGroupManager::allreduce(ProxyType<ObjT> proxy, Args&&... data) {
327344
return allreduce<Reducer>(proxy, std::forward<Args>(data)...);
328345
}
329346

347+
// Silence nvcc warning
330348
return PendingSendType{nullptr};
331349
}
332350

src/vt/objgroup/proxy/proxy_objgroup.impl.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
#include "vt/messaging/param_msg.h"
5757
#include "vt/objgroup/proxy/proxy_bits.h"
5858
#include "vt/collective/reduce/get_reduce_stamp.h"
59+
#include "vt/collective/reduce/allreduce/helpers.h"
5960

6061
namespace vt { namespace objgroup { namespace proxy {
6162

@@ -215,8 +216,7 @@ Proxy<ObjT>::allreduce_h(
215216
) const {
216217
auto proxy = Proxy<ObjT>(*this);
217218

218-
// using DataT = std::tuple<std::decay_t<Args>...>;
219-
return theObjGroup()->allreduce<f, ObjT, Op, std::decay_t<Args>...>(
219+
return theObjGroup()->allreduce<f, ObjT, Op, remove_cvref<Args>...>(
220220
proxy, std::forward<Args>(args)...);
221221
}
222222

src/vt/utils/hash/hash_tuple.h

+35-7
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,44 @@
4747
#include <tuple>
4848

4949
namespace std {

/**
 * \brief \c std::hash specialization for \c std::tuple, so that tuples can be
 * used as keys in unordered containers.
 *
 * Every element is hashed with \c std::hash of its own type and mixed into a
 * running seed in element order (index 0 first), so both the values and their
 * positions contribute to the result. An empty tuple hashes to 0.
 */
template <typename... TT>
struct hash<std::tuple<TT...>> {
private:
  // Mixing step, code from boost (hash_combine).
  // Reciprocal of the golden ratio helps spread entropy
  // and handles duplicates.
  // See Mike Seymour in magic-numbers-in-boosthash-combine:
  // http://stackoverflow.com/questions/4948780
  //
  // This is a private static member on purpose: a helper placed in an unnamed
  // namespace in a header gets internal linkage in every translation unit
  // while being called from this externally-linked inline operator(), and it
  // would also add an extra name to namespace std.
  template <class T>
  static void combine(size_t& seed, T const& v) {
    seed ^= std::hash<T>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
  }

public:
  /**
   * \brief Compute the combined hash of all tuple elements.
   *
   * \param[in] tt the tuple to hash
   *
   * \return the seed after every element has been mixed in
   */
  size_t operator()(std::tuple<TT...> const& tt) const {
    size_t seed = 0;
    // Comma fold visits the elements left to right; for an empty tuple the
    // fold expands to nothing, so the seed stays 0 instead of failing to
    // compile.
    std::apply(
      [&seed](auto const&... elems) { (hash::combine(seed, elems), ...); },
      tt
    );
    return seed;
  }
};

} // namespace std
6189

6290
#endif /*INCLUDED_VT_UTILS_HASH_HASH_TUPLE_H*/

tests/unit/objgroup/test_objgroup.cc

+11-7
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@
4343

4444
#include "test_objgroup_common.h"
4545
#include "test_helpers.h"
46+
#include "vt/collective/reduce/allreduce/rabenseifner.h"
47+
#include "vt/configs/types/types_type.h"
4648
#include "vt/objgroup/manager.h"
4749

4850
#include <typeinfo>
@@ -266,7 +268,7 @@ TEST_F(TestObjGroup, test_proxy_allreduce) {
266268
auto const my_node = vt::theContext()->getNode();
267269

268270
TestObjGroup::total_verify_expected_ = 0;
269-
auto proxy = vt::theObjGroup()->makeCollective<MyObjA>("test_proxy_reduce");
271+
auto proxy = vt::theObjGroup()->makeCollective<MyObjA>("test_proxy_allreduce");
270272

271273
vt::theCollective()->barrier();
272274

@@ -289,7 +291,7 @@ TEST_F(TestObjGroup, test_proxy_allreduce) {
289291
EXPECT_EQ(MyObjA::total_verify_expected_, 3);
290292
runInEpochCollective([&] {
291293
using Reducer = vt::collective::reduce::allreduce::RecursiveDoubling<
292-
std::vector<int>, PlusOp, MyObjA, &MyObjA::verifyAllredVec
294+
std::vector<int>, PlusOp, MyObjA, &MyObjA::verifyAllredVec<int, 256>
293295
>;
294296
std::vector<int> payload(256, my_node);
295297
theObjGroup()->allreduce<Reducer>(proxy, payload);
@@ -299,18 +301,20 @@ TEST_F(TestObjGroup, test_proxy_allreduce) {
299301

300302
runInEpochCollective([&] {
301303
using Reducer = vt::collective::reduce::allreduce::Rabenseifner<
302-
std::vector<int>, PlusOp, MyObjA, &MyObjA::verifyAllredVec
304+
NodeType, PlusOp, MyObjA, &MyObjA::verifyAllred<1>
303305
>;
304-
std::vector<int> payload(256, my_node);
305-
theObjGroup()->allreduce<Reducer>(proxy, payload);
306-
theObjGroup()->allreduce<Reducer>(proxy, payload);
306+
std::vector<int> payload(2048, my_node);
307+
theObjGroup()->allreduce<Reducer>(proxy, my_node);
308+
309+
std::vector<short> payload_large(2048 * 2, my_node);
310+
theObjGroup()->allreduce<Reducer>(proxy, my_node);
307311
});
308312

309313
EXPECT_EQ(MyObjA::total_verify_expected_, 6);
310314

311315
runInEpochCollective([&] {
312316
using Reducer = vt::collective::reduce::allreduce::Rabenseifner<
313-
VectorPayload, PlusOp, MyObjA, &MyObjA::verifyAllredVecPayload>;
317+
VectorPayload, PlusOp, MyObjA, &MyObjA::verifyAllredVecPayload<VectorPayload, 256>>;
314318
std::vector<int> payload(256, my_node);
315319
VectorPayload data{payload};
316320
theObjGroup()->allreduce<Reducer>(proxy, data);

tests/unit/objgroup/test_objgroup_common.h

+8-4
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,12 @@ struct MyObjA {
131131
total_verify_expected_++;
132132
}
133133

134-
void verifyAllredVec(std::vector<int> vec) {
134+
template<typename Scalar, int32_t size>
135+
void verifyAllredVec(std::vector<Scalar> vec) {
135136
auto final_size = vec.size();
136-
EXPECT_EQ(final_size, 256);
137+
EXPECT_EQ(final_size, size);
137138

138-
auto n = vt::theContext()->getNumNodes();
139+
auto const n = theContext()->getNumNodes();
139140
auto const total_sum = n * (n - 1)/2;
140141
for(auto val : vec){
141142
EXPECT_EQ(val, total_sum);
@@ -144,7 +145,10 @@ struct MyObjA {
144145
total_verify_expected_++;
145146
}
146147

147-
void verifyAllredVecPayload(VectorPayload vec) { verifyAllredVec(vec.vec_); }
148+
template <typename DataT, int32_t size>
149+
void verifyAllredVecPayload(VectorPayload vec) {
150+
verifyAllredVec<typename decltype(DataT::vec_)::value_type, size>(vec.vec_);
151+
}
148152

149153
#if MAGISTRATE_KOKKOS_ENABLED
150154
void verifyAllredView(Kokkos::View<float*, Kokkos::HostSpace> view) {

0 commit comments

Comments
 (0)