|
41 | 41 | //@HEADER
|
42 | 42 | */
|
43 | 43 |
|
| 44 | +#include "vt/configs/types/types_sentinels.h" |
44 | 45 | #if !defined INCLUDED_VT_OBJGROUP_MANAGER_IMPL_H
|
45 | 46 | #define INCLUDED_VT_OBJGROUP_MANAGER_IMPL_H
|
46 | 47 |
|
|
#include "vt/messaging/active.h"
#include "vt/elm/elm_id_bits.h"
#include "vt/messaging/message/smart_ptr.h"
#include "vt/collective/reduce/allreduce/rabenseifner.h"
#include "vt/collective/reduce/allreduce/recursive_doubling.h"

#include <array>
#include <cstddef>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>
|
64 | 68 |
|
@@ -264,57 +268,70 @@ ObjGroupManager::PendingSendType ObjGroupManager::broadcast(MsgSharedPtr<MsgT> m
|
264 | 268 | return objgroup::broadcast(msg,han);
|
265 | 269 | }
|
266 | 270 |
|
| 271 | + |
// Detects whether T is a specialization of the class template `Template`.
// Only templates whose parameters are all type parameters can be matched
// (e.g. std::vector<T, Alloc>); a template with a non-type parameter such as
// std::array<T, N> does not fit `template <typename...> class` and needs its
// own trait (see is_std_array). The match is exact: cv-qualified or reference
// types yield false.
template <template <typename...> class Template, typename T>
struct is_specialization_of : std::false_type {};

template <template <typename...> class Template, typename... Args>
struct is_specialization_of<Template, Template<Args...>> : std::true_type {};

// Detects std::array<T, N> for any element type and extent. The match is
// exact: cv-qualified or reference types yield false.
template <typename T>
struct is_std_array : std::false_type {};

template <typename T, std::size_t N>
struct is_std_array<std::array<T, N>> : std::true_type {};

// True when T is a std::vector or a std::array. cv-qualifiers and references
// are stripped before the check, so `const std::vector<int>` and
// `std::array<int, 2> const&` are classified the same way as the plain type.
template <typename T>
struct is_std_container
  : std::bool_constant<
      is_specialization_of<
        std::vector, std::remove_cv_t<std::remove_reference_t<T>>
      >::value ||
      is_std_array<std::remove_cv_t<std::remove_reference_t<T>>>::value
    > {};
| 290 | + |
| 291 | +template < |
| 292 | + typename Reducer, auto f, typename ObjT, template <typename Arg> class Op, typename DataT> |
| 293 | +ObjGroupManager::PendingSendType ObjGroupManager::allreduce( |
| 294 | + ProxyType<ObjT> proxy, const DataT& data) { |
| 295 | + return PendingSendType{ |
| 296 | + theTerm()->getEpoch(), [=] { |
| 297 | + auto const this_node = vt::theContext()->getNode(); |
| 298 | + auto const num_nodes = theContext()->getNumNodes(); |
| 299 | + |
| 300 | + proxy::Proxy<Reducer> grp_proxy = {}; |
| 301 | + |
| 302 | + if (reducers_.find(proxy.getProxy()) != reducers_.end()) { |
| 303 | + auto* obj = reinterpret_cast<Reducer*>( |
| 304 | + objs_[reducers_[proxy.getProxy()]]->getPtr() |
| 305 | + ); |
| 306 | + obj->initialize(data); |
| 307 | + grp_proxy = obj->proxy_; |
| 308 | + } else { |
| 309 | + grp_proxy = vt::theObjGroup()->makeCollective<Reducer>( |
| 310 | + "allreduce_rabenseifner", proxy, num_nodes, data); |
| 311 | + grp_proxy[this_node].get()->proxy_ = grp_proxy; |
| 312 | + } |
| 313 | + |
| 314 | + grp_proxy[this_node].template invoke<&Reducer::allreduce>(); |
| 315 | + }}; |
| 316 | +} |
| 317 | + |
267 | 318 | template <
|
268 | 319 | auto f, typename ObjT, template <typename Arg> class Op, typename DataT>
|
269 | 320 | ObjGroupManager::PendingSendType
|
270 | 321 | ObjGroupManager::allreduce(ProxyType<ObjT> proxy, const DataT& data) {
|
271 |
| - // check payload size and choose appropriate algorithm |
272 |
| - |
273 |
| - auto const this_node = vt::theContext()->getNode(); |
274 |
| - auto const num_nodes = theContext()->getNumNodes(); |
275 |
| - |
276 |
| - if (num_nodes < 2) { |
| 322 | + if (theContext()->getNumNodes() < 2) { |
277 | 323 | return PendingSendType{nullptr};
|
278 | 324 | }
|
279 | 325 |
|
280 |
| - // using Reducer = collective::reduce::allreduce::Rabenseifner<DataT>; |
281 |
| - // using Reducer = collective::reduce::allreduce::DistanceDoubling<DataT, Op, ObjT, f>; |
282 |
| - |
283 |
| - return PendingSendType{theTerm()->getEpoch(), [=] { |
284 |
| - // auto grp_proxy = |
285 |
| - // vt::theObjGroup()->makeCollective<Reducer>("allreduce_rabenseifner"); |
286 |
| - // if constexpr (std::is_same_v< |
287 |
| - // Reducer, |
288 |
| - // collective::reduce::allreduce::DistanceDoubling<DataT, Op, ObjT, f>>) { |
289 |
| - // grp_proxy[this_node].template invoke<&Reducer::initialize>( |
290 |
| - // data, grp_proxy, proxy, num_nodes); |
291 |
| - |
292 |
| - // grp_proxy[this_node].template invoke<&Reducer::partOne>(); |
293 |
| - |
294 |
| - // } else if constexpr (std::is_same_v< |
295 |
| - // Reducer, |
296 |
| - // collective::reduce::allreduce::Rabenseifner< |
297 |
| - // DataT, Op, ObjT, f>>) { |
298 |
| - // grp_proxy[this_node].template invoke<&Reducer::initialize>( |
299 |
| - // data, grp_proxy, num_nodes); |
300 |
| - |
301 |
| - // if (grp_proxy.get()->nprocs_rem_) { |
302 |
| - // vt::runInEpochCollective( |
303 |
| - // [=] { grp_proxy[this_node].template invoke<&Reducer::partOne>(); }); |
304 |
| - // } |
305 |
| - |
306 |
| - // vt::runInEpochCollective( |
307 |
| - // [=] { grp_proxy[this_node].template invoke<&Reducer::partTwo>(); }); |
308 |
| - |
309 |
| - // vt::runInEpochCollective( |
310 |
| - // [=] { grp_proxy[this_node].template invoke<&Reducer::partThree>(); }); |
311 |
| - |
312 |
| - // if (grp_proxy.get()->nprocs_rem_) { |
313 |
| - // vt::runInEpochCollective( |
314 |
| - // [=] { grp_proxy[this_node].template invoke<&Reducer::partFour>(); }); |
315 |
| - // } |
316 |
| - // } |
317 |
| - }}; |
| 326 | + if constexpr (is_std_container<DataT>::value) { |
| 327 | + using Reducer = |
| 328 | + vt::collective::reduce::allreduce::Rabenseifner<DataT, Op, ObjT, f>; |
| 329 | + return allreduce<Reducer, f, ObjT, Op>(proxy, data); |
| 330 | + } else { |
| 331 | + using Reducer = |
| 332 | + vt::collective::reduce::allreduce::DistanceDoubling<DataT, Op, ObjT, f>; |
| 333 | + return allreduce<Reducer, f, ObjT, Op>(proxy, data); |
| 334 | + } |
318 | 335 | }
|
319 | 336 |
|
320 | 337 | template <typename ObjT, typename MsgT, ActiveTypedFnType<MsgT> *f>
|
|
0 commit comments