From 979c52f8be9c0a40554a8668e69ca576e33a5b19 Mon Sep 17 00:00:00 2001 From: Jonathan Lifflander Date: Mon, 24 Apr 2023 22:07:46 -0700 Subject: [PATCH] #276: pipe: finish implementation across objgroup and collections --- examples/collection/jacobi1d_vt.cc | 21 +-- examples/collection/jacobi2d_vt.cc | 21 +-- examples/collection/reduce_integral.cc | 32 ++-- examples/collection/transpose.cc | 8 +- .../hello_world_collection_reduce.cc | 19 +-- src/vt/collective/reduce/get_reduce_stamp.h | 99 +++++++++++++ .../collective/reduce/operators/default_msg.h | 111 ++++++++++---- .../reduce/operators/default_op.impl.h | 24 ++- src/vt/collective/reduce/reduce.h | 13 +- src/vt/objgroup/proxy/proxy_objgroup.h | 65 ++++++++- src/vt/objgroup/proxy/proxy_objgroup.impl.h | 101 +++++++++++++ src/vt/phase/phase_manager.cc | 14 +- src/vt/phase/phase_manager.h | 7 +- src/vt/pipe/callback/cb_union/cb_raw_base.h | 10 ++ src/vt/pipe/pipe_manager_tl.impl.h | 2 +- src/vt/termination/graph/epoch_graph_reduce.h | 8 +- src/vt/utils/fntraits/fntraits.h | 42 +++--- .../vrt/collection/balance/baselb/baselb.cc | 7 +- src/vt/vrt/collection/balance/baselb/baselb.h | 2 +- .../collection/balance/greedylb/greedylb.cc | 17 +-- .../collection/balance/greedylb/greedylb.h | 2 +- .../balance/hierarchicallb/hierlb.cc | 6 +- .../balance/hierarchicallb/hierlb.h | 2 +- .../balance/hierarchicallb/hierlb_msgs.h | 2 - .../balance/lb_data_restart_reader.cc | 10 +- .../balance/lb_data_restart_reader.h | 4 +- .../balance/lb_invoke/lb_manager.cc | 10 +- .../collection/balance/lb_invoke/lb_manager.h | 7 +- src/vt/vrt/collection/balance/stats_msg.h | 21 --- .../balance/temperedlb/tempered_msgs.h | 40 ----- .../balance/temperedlb/temperedlb.cc | 40 ++--- .../balance/temperedlb/temperedlb.h | 8 +- src/vt/vrt/collection/reducable/reducable.h | 68 +++++++++ .../vrt/collection/reducable/reducable.impl.h | 93 +++++++++++- .../test_collection_group.extended.cc | 30 ++-- .../test_collection_group_recreate.cc | 72 ++++----- 
.../unit/collection/test_reduce_collection.cc | 26 +--- .../test_reduce_collection_common.h | 88 ++--------- .../test_reduce_collection_handler.h | 137 ++---------------- .../collection/test_reduce_collection_race.cc | 20 +-- .../collectives/test_collectives_reduce.cc | 21 +-- tests/unit/lb/test_lb_data_comm.cc | 32 +--- tests/unit/objgroup/test_objgroup.cc | 60 +++----- tests/unit/objgroup/test_objgroup_common.h | 25 ---- tutorial/tutorial_2b.h | 39 ++--- 45 files changed, 772 insertions(+), 714 deletions(-) create mode 100644 src/vt/collective/reduce/get_reduce_stamp.h diff --git a/examples/collection/jacobi1d_vt.cc b/examples/collection/jacobi1d_vt.cc index bd4b301761..3a10699029 100644 --- a/examples/collection/jacobi1d_vt.cc +++ b/examples/collection/jacobi1d_vt.cc @@ -73,9 +73,7 @@ static constexpr double const default_tol = 1.0e-02; struct NodeObj { bool is_finished_ = false; - struct WorkFinishedMsg : vt::Message {}; - - void workFinishedHandler(WorkFinishedMsg*) { is_finished_ = true; } + void workFinishedHandler() { is_finished_ = true; } bool isWorkFinished() { return is_finished_; } }; using NodeObjProxy = vt::objgroup::proxy::Proxy; @@ -119,19 +117,12 @@ struct LinearPb1DJacobi : vt::Collection { }; - struct ReduxMsg : vt::collective::ReduceTMsg { - ReduxMsg() = default; - explicit ReduxMsg(double in_val) : ReduceTMsg(in_val) { } - }; - - void checkCompleteCB(ReduxMsg* msg) { + void checkCompleteCB(double normRes) { // // Only one object for the reduction will visit // this function // - double normRes = msg->getConstVal(); - auto const iter_max_reached = iter_ > maxIter_; auto const norm_res_done = normRes < default_tol; @@ -143,7 +134,7 @@ struct LinearPb1DJacobi : vt::Collection { fmt::print(to_print); // Notify all nodes that computation is finished - objProxy_.broadcast(); + objProxy_.broadcast<&NodeObj::workFinishedHandler>(); } else { fmt::print(" ## ITER {} >> Residual Norm = {} \n", iter_, normRes); } @@ -184,9 +175,9 @@ struct 
LinearPb1DJacobi : vt::Collection { } auto proxy = this->getCollectionProxy(); - auto cb = vt::theCB()->makeSend<&LinearPb1DJacobi::checkCompleteCB>(proxy[0]); - auto msg2 = vt::makeMessage(maxNorm); - proxy.reduce>(msg2.get(),cb); + proxy.reduce<&LinearPb1DJacobi::checkCompleteCB, vt::collective::MaxOp>( + proxy[0], maxNorm + ); } diff --git a/examples/collection/jacobi2d_vt.cc b/examples/collection/jacobi2d_vt.cc index a24fd7bca6..eba9608c43 100644 --- a/examples/collection/jacobi2d_vt.cc +++ b/examples/collection/jacobi2d_vt.cc @@ -80,9 +80,7 @@ static constexpr double const default_tol = 1.0e-02; struct NodeObj { bool is_finished_ = false; - struct WorkFinishedMsg : vt::Message {}; - - void workFinishedHandler(WorkFinishedMsg*) { is_finished_ = true; } + void workFinishedHandler() { is_finished_ = true; } bool isWorkFinished() { return is_finished_; } }; using NodeObjProxy = vt::objgroup::proxy::Proxy; @@ -129,19 +127,12 @@ struct LinearPb2DJacobi : vt::Collection { }; - struct ReduxMsg : vt::collective::ReduceTMsg { - ReduxMsg() = default; - explicit ReduxMsg(double in_val) : ReduceTMsg(in_val) { } - }; - - - void checkCompleteCB(ReduxMsg* msg) { + void checkCompleteCB(double const normRes) { // // Only one object for the reduction will visit // this function // - const double normRes = msg->getConstVal(); auto const iter_max_reached = iter_ > maxIter_; auto const norm_res_done = normRes < default_tol; @@ -153,7 +144,7 @@ struct LinearPb2DJacobi : vt::Collection { fmt::print(to_print); // Notify all nodes that computation is finished - objProxy_.broadcast(); + objProxy_.broadcast<&NodeObj::workFinishedHandler>(); } else { fmt::print(" ## ITER {} >> Residual Norm = {} \n", iter_, normRes); } @@ -222,9 +213,9 @@ struct LinearPb2DJacobi : vt::Collection { } auto proxy = this->getCollectionProxy(); - auto cb = vt::theCB()->makeSend<&LinearPb2DJacobi::checkCompleteCB>(proxy(0,0)); - auto msg2 = vt::makeMessage(maxNorm); - proxy.reduce>(msg2.get(),cb); + 
proxy.reduce<&LinearPb2DJacobi::checkCompleteCB, vt::collective::MaxOp>( + proxy(0,0), maxNorm + ); } struct VecMsg : vt::CollectionMessage { diff --git a/examples/collection/reduce_integral.cc b/examples/collection/reduce_integral.cc index 45a40fbcb5..a88c06d229 100644 --- a/examples/collection/reduce_integral.cc +++ b/examples/collection/reduce_integral.cc @@ -101,21 +101,17 @@ struct Integration1D : vt::Collection { numPartsPerObject_(default_nparts_object) { } - struct CheckIntegral { - - void operator()(ReduceMsg* msg) { - fmt::print(" >> The integral over [0, 1] is {}\n", msg->getConstVal()); - fmt::print( - " >> The absolute error is {}\n", - std::fabs(msg->getConstVal() - exactIntegral) - ); - // - // Set the 'root_reduce_finished' variable to true. - // - root_reduce_finished = true; - } - - }; + void checkIntegral(double val) { + fmt::print(" >> The integral over [0, 1] is {}\n", val); + fmt::print( + " >> The absolute error is {}\n", + std::fabs(val - exactIntegral) + ); + // + // Set the 'root_reduce_finished' variable to true. 
+ // + root_reduce_finished = true; + } struct InitMsg : vt::CollectionMessage { @@ -177,9 +173,9 @@ struct Integration1D : vt::Collection { // auto proxy = this->getCollectionProxy(); - auto msgCB = vt::makeMessage(quadsum); - auto cback = vt::theCB()->makeSend(reduce_root_node); - proxy.reduce>(msgCB.get(),cback); + proxy.reduce<&Integration1D::checkIntegral, vt::collective::PlusOp>( + proxy[0], quadsum + ); } }; diff --git a/examples/collection/transpose.cc b/examples/collection/transpose.cc index 6baf3e21c3..0c016c876d 100644 --- a/examples/collection/transpose.cc +++ b/examples/collection/transpose.cc @@ -58,8 +58,6 @@ struct RequestDataMsg : vt::CollectionMessage { vt::NodeType node_; }; -struct InitMsg : vt::collective::ReduceNoneMsg { }; - struct DataMsg : vt::Message { using MessageParentType = vt::Message; vt_msg_serialize_required(); // by payload_ @@ -170,7 +168,7 @@ struct Block : vt::Collection { ); } - void doneInit(InitMsg* msg) { + void doneInit() { if (getIndex().x() == 0) { auto proxy = this->getCollectionProxy(); auto proxy_msg = vt::makeMessage(proxy.getProxy()); @@ -183,9 +181,7 @@ struct Block : vt::Collection { initialize(); // Wait for all initializations to complete auto proxy = this->getCollectionProxy(); - auto cb = vt::theCB()->makeBcast<&Block::doneInit>(proxy); - auto empty = vt::makeMessage(); - proxy.reduce(empty.get(), cb); + proxy.allreduce<&Block::doneInit>(); } private: diff --git a/examples/hello_world/hello_world_collection_reduce.cc b/examples/hello_world/hello_world_collection_reduce.cc index d15deee20d..5bcee047b6 100644 --- a/examples/hello_world/hello_world_collection_reduce.cc +++ b/examples/hello_world/hello_world_collection_reduce.cc @@ -45,24 +45,19 @@ /// [Hello world reduce collection] struct Hello : vt::Collection { - using ReduceMsg = vt::collective::ReduceTMsg; - - void done(ReduceMsg* msg) { - fmt::print("Reduce complete at {} value {}\n", this->getIndex(), msg->getVal()); + void done(int val, double val2) { + 
fmt::print("Reduce complete at {} value {} {}\n", getIndex(), val, val2); } void doWork() { - fmt::print("Hello from {}\n", this->getIndex()); + fmt::print("Hello from {}\n", getIndex()); // Get the proxy for the collection - auto proxy = this->getCollectionProxy(); - - // Create a callback for when the reduction finishes - auto cb = vt::theCB()->makeSend<&Hello::done>(proxy(2)); + auto proxy = getCollectionProxy(); - // Create and send the reduction message holding an int - auto red_msg = vt::makeMessage(this->getIndex().x()); - proxy.reduce>(red_msg.get(),cb); + auto val = getIndex().x(); + auto val2 = 2.4; + proxy.allreduce<&Hello::done, vt::collective::PlusOp>(val, val2); } }; diff --git a/src/vt/collective/reduce/get_reduce_stamp.h b/src/vt/collective/reduce/get_reduce_stamp.h new file mode 100644 index 0000000000..50ca23008d --- /dev/null +++ b/src/vt/collective/reduce/get_reduce_stamp.h @@ -0,0 +1,99 @@ +/* +//@HEADER +// ***************************************************************************** +// +// get_reduce_stamp.h +// DARMA/vt => Virtual Transport +// +// Copyright 2019-2021 National Technology & Engineering Solutions of Sandia, LLC +// (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the U.S. +// Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. 
+// +// * Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from this +// software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact darma@sandia.gov +// +// ***************************************************************************** +//@HEADER +*/ + +#if !defined INCLUDED_VT_COLLECTIVE_REDUCE_GET_REDUCE_STAMP_H +#define INCLUDED_VT_COLLECTIVE_REDUCE_GET_REDUCE_STAMP_H + +#include "vt/config.h" + +namespace vt { namespace collective { namespace reduce { + +template +struct GetReduceStamp : std::false_type { + template + static auto getMsg(Args&&... args) { + return vt::makeMessage(std::tuple{std::forward(args)...}); + } +}; + +template <> +struct GetReduceStamp< + std::enable_if_t> +> : std::false_type { + template + static auto getMsg() { + return vt::makeMessage(std::tuple<>{}); + } +}; + +template +struct GetReduceStamp< + std::enable_if_t< + std::is_same_v< + std::decay_t>>, + collective::reduce::ReduceStamp + > + >, + Args... 
+> : std::true_type { + template + static constexpr auto getMsgHelper( + std::tuple tp, std::index_sequence + ) { + return std::tuple{std::get(tp)...}; + } + + template + static auto getMsg(Args&&... args) { + return vt::makeMessage( + getMsgHelper( + std::tie(std::forward(args)...), + std::make_index_sequence{} + ) + ); + } +}; + +}}} /* end namespace vt::collective::reduce */ + +#endif /*INCLUDED_VT_COLLECTIVE_REDUCE_GET_REDUCE_STAMP_H*/ diff --git a/src/vt/collective/reduce/operators/default_msg.h b/src/vt/collective/reduce/operators/default_msg.h index 1ed218e0ab..5945393a53 100644 --- a/src/vt/collective/reduce/operators/default_msg.h +++ b/src/vt/collective/reduce/operators/default_msg.h @@ -51,27 +51,53 @@ #include #include +#include namespace vt { namespace collective { namespace reduce { namespace operators { template struct ReduceCombine; +template +struct GetCallbackType; + +template +struct GetCallbackType { + using CallbackType = Callback; + using MsgT = T; +}; + +template +struct GetCallbackType> { + using CallbackType = Callback; +}; + +template +struct GetCallbackType> { + using MsgT = T; +}; + template struct ReduceDataMsg : SerializeIfNeeded< ReduceMsg, ReduceDataMsg, DataType >, ReduceCombine { + + using DataT = DataType; + using CallbackParamType = typename GetCallbackType::CallbackType; + using CallbackMsgType = Callback>; + using MessageParentType = SerializeIfNeeded< ReduceMsg, ReduceDataMsg, DataType >; - using CallbackType = CallbackU; - ReduceDataMsg() = default; + ReduceDataMsg(ReduceDataMsg const&) = default; + ReduceDataMsg(ReduceDataMsg&&) = default; + explicit ReduceDataMsg(DataType&& in_val) : MessageParentType(), ReduceCombine(), val_(std::forward(in_val)) @@ -84,46 +110,73 @@ struct ReduceDataMsg : SerializeIfNeeded< DataType const& getConstVal() const { return val_; } DataType& getVal() { return val_; } DataType&& getMoveVal() { return std::move(val_); } - CallbackType getCallback() { return cb_; } - template - void 
setCallback(Callback cb) { cb_ = CallbackType{cb}; } + bool isMsgCallback() const { return cb_.index() == 0; } + bool isParamCallback() const { return cb_.index() == 1; } + CallbackMsgType getMsgCallback() { return std::get<0>(cb_); } + CallbackParamType getParamCallback() { return std::get<1>(cb_); } + bool hasValidCallback() { + if (isMsgCallback()) { + return getMsgCallback().valid(); + } else { + return getParamCallback().valid(); + } + } + + template + void setCallback(CallbackT cb) { + if constexpr (std::is_same_v) { + cb_ = cb; + } else if ( + std::is_convertible_v< + typename GetCallbackType::MsgT*, ReduceDataMsg* + > + ) { + auto cb_ptr = reinterpret_cast>*>(&cb); + cb_ = *cb_ptr; + } else { + static_assert( + std::is_same_v or + std::is_convertible_v< + typename GetCallbackType::MsgT*, ReduceDataMsg* + >, + "Must be a convertible message callback or parameterized callback" + ); + } + } template void serialize(SerializeT& s) { MessageParentType::serialize(s); s | val_; - s | cb_; + int index = cb_.index(); + s | index; + if (s.isUnpacking()) { + if (index == 0) { + CallbackMsgType cb; + s | cb; + cb_ = cb; + } else { + CallbackParamType cb; + s | cb; + cb_ = cb; + } + } else { + if (index == 0) { + s | std::get<0>(cb_); + } else { + s | std::get<1>(cb_); + } + } } protected: DataType val_ = {}; - CallbackType cb_ = {}; + std::variant cb_; }; template -struct ReduceTMsg : SerializeIfNeeded< - ReduceDataMsg, - ReduceTMsg -> { - using MessageParentType = SerializeIfNeeded< - ReduceDataMsg, - ReduceTMsg - >; - - ReduceTMsg() = default; - explicit ReduceTMsg(T&& in_val) - : MessageParentType(std::forward(in_val)) - { } - explicit ReduceTMsg(T const& in_val) - : MessageParentType(in_val) - { } - - template - inline void serialize(SerializeT& s) { - MessageParentType::serialize(s); - } -}; +using ReduceTMsg = ReduceDataMsg; template struct ReduceArrMsg : SerializeIfNeeded< diff --git a/src/vt/collective/reduce/operators/default_op.impl.h 
b/src/vt/collective/reduce/operators/default_op.impl.h index a352ab94e9..fde20ff951 100644 --- a/src/vt/collective/reduce/operators/default_op.impl.h +++ b/src/vt/collective/reduce/operators/default_op.impl.h @@ -51,18 +51,31 @@ namespace vt { namespace collective { namespace reduce { namespace operators { struct NoCombine {}; +template +struct IsTuple : std::false_type {}; +template +struct IsTuple> : std::true_type {}; + template template /*static*/ void ReduceCombine::msgHandler(MsgT* msg) { if (msg->isRoot()) { - auto cb = msg->getCallback(); vt_debug_print( terse, reduce, - "ROOT: reduce root: valid={}, ptr={}\n", cb.valid(), print_ptr(msg) + "ROOT: reduce root: ptr={}\n", print_ptr(msg) ); - if (cb.valid()) { + if (msg->hasValidCallback()) { envelopeUnlockForForwarding(msg->env); - cb.template send(msg); + if (msg->isParamCallback()) { + if constexpr (IsTuple::value) { + msg->getParamCallback().sendTuple(std::move(msg->getVal())); + } + } else { + // We need to force the type to the more specific one here + auto cb = msg->getMsgCallback(); + auto typed_cb = reinterpret_cast*>(&cb); + typed_cb->send(msg); + } } else if (msg->root_handler_ != uninitialized_handler) { auto_registry::getAutoHandler(msg->root_handler_)->dispatch(msg, nullptr); } else { @@ -75,8 +88,7 @@ template MsgT* cur_msg = msg->template getNext(); vt_debug_print( verbose, reduce, - "leaf: fst valid={}, ptr={}\n", fst_msg->getCallback().valid(), - print_ptr(fst_msg) + "leaf: fst ptr={}\n", print_ptr(fst_msg) ); while (cur_msg != nullptr) { ReduceCombine<>::combine(fst_msg, cur_msg); diff --git a/src/vt/collective/reduce/reduce.h b/src/vt/collective/reduce/reduce.h index 3b0934521c..be969eebd1 100644 --- a/src/vt/collective/reduce/reduce.h +++ b/src/vt/collective/reduce/reduce.h @@ -171,25 +171,16 @@ struct Reduce : virtual collective::tree::Tree { ////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////// - template - struct FunctionTraitsArgs; - - template - struct FunctionTraitsArgs { - using TupleType = std::tuple...>; - using ReturnType = ReturnT; - }; - template class Op, auto f, typename... Params> PendingSendType reduce(Node root, Params&&... params) { - using Tuple = typename FunctionTraitsArgs::TupleType; + using Tuple = typename FuncTraits::TupleType; using OpT = Op; return reduce(root, std::forward(params)...); } template PendingSendType reduce(Node root, Params&&... params) { - using Tuple = typename FunctionTraitsArgs::TupleType; + using Tuple = typename FuncTraits::TupleType; using MsgT = ReduceTMsg; auto msg = vt::makeMessage(std::tuple{std::forward(params)...}); diff --git a/src/vt/objgroup/proxy/proxy_objgroup.h b/src/vt/objgroup/proxy/proxy_objgroup.h index b235bdcf95..1d4c34b525 100644 --- a/src/vt/objgroup/proxy/proxy_objgroup.h +++ b/src/vt/objgroup/proxy/proxy_objgroup.h @@ -94,7 +94,6 @@ struct Proxy { { } public: - /** * \brief Broadcast a message to all nodes to be delivered to the local object * instance @@ -160,6 +159,64 @@ struct Proxy { template PendingSendType broadcast(Args&&... args) const; + /** + * \brief All-reduce back to this objgroup. Performs a reduction using + * operator `Op` followed by a broadcast to `f` with the result. + * + * \param[in] args the arguments to reduce. \note The last argument optionally + * may be a `ReduceStamp`. + * + * \return a pending send + */ + template < + auto f, + template class Op = collective::NoneOp, + typename... Args + > + PendingSendType allreduce( + Args&&... args + ) const; + + /** + * \brief Reduce back to a point target. Performs a reduction using operator + * `Op` followed by a send to `f` with the result. + * + * \param[in] args the arguments to reduce. \note The last argument optionally + * may be a `ReduceStamp`. 
+ * + * \return a pending send + */ + template < + auto f, + template class Op = collective::NoneOp, + typename Target, + typename... Args + > + PendingSendType reduce( + Target target, + Args&&... args + ) const; + + /** + * \brief Reduce back to an arbitrary callback. Performs a reduction using + * operator `Op` and then delivers the result to the callback `cb`. + * + * \param[in] cb the callback to trigger with the reduction result + * \param[in] args the arguments to reduce. \note The last argument optionally + * may be a `ReduceStamp`. + * + * \return a pending send + */ + template < + template class Op = collective::NoneOp, + typename... CBArgs, + typename... Args + > + PendingSendType reduce( + vt::Callback cb, + Args&&... args + ) const; + /** * \brief Reduce over the objgroup instances on each node with a callback * target. @@ -176,6 +233,7 @@ struct Proxy { typename MsgT, ActiveTypedFnType *f > + [[deprecated("Use new interface calls (allreduce/reduce) without message")]] PendingSendType reduce( MsgPtrT msg, Callback cb, ReduceStamp stamp = ReduceStamp{} ) const; @@ -185,6 +243,7 @@ struct Proxy { typename MsgPtrT, typename MsgT = typename util::MsgPtrType::MsgType > + [[deprecated("Use new interface calls (allreduce/reduce) without message")]] PendingSendType reduce( MsgPtrT msg, Callback cb, ReduceStamp stamp = ReduceStamp{} ) const { @@ -214,13 +273,16 @@ struct Proxy { typename MsgT = typename util::MsgPtrType::MsgType, ActiveTypedFnType *f > + [[deprecated("Use new interface calls (allreduce/reduce) without message")]] PendingSendType reduce(MsgPtrT msg, ReduceStamp stamp = ReduceStamp{}) const; + template < typename OpT = collective::None, typename FunctorT, typename MsgPtrT, typename MsgT = typename util::MsgPtrType::MsgType > + [[deprecated("Use new interface calls (allreduce/reduce) without message")]] PendingSendType reduce(MsgPtrT msg, ReduceStamp stamp = ReduceStamp{}) const { return reduce< @@ -246,6 +308,7 @@ struct Proxy { typename MsgT = 
typename util::MsgPtrType::MsgType, ActiveTypedFnType *f > + [[deprecated("Use new interface calls (allreduce/reduce) without message")]] PendingSendType reduce(MsgPtrT msg, ReduceStamp stamp = ReduceStamp{}) const; /** diff --git a/src/vt/objgroup/proxy/proxy_objgroup.impl.h b/src/vt/objgroup/proxy/proxy_objgroup.impl.h index 46cf20ad69..21c380608f 100644 --- a/src/vt/objgroup/proxy/proxy_objgroup.impl.h +++ b/src/vt/objgroup/proxy/proxy_objgroup.impl.h @@ -54,6 +54,7 @@ #include "vt/rdmahandle/manager.h" #include "vt/messaging/param_msg.h" #include "vt/objgroup/proxy/proxy_bits.h" +#include "vt/collective/reduce/get_reduce_stamp.h" namespace vt { namespace objgroup { namespace proxy { @@ -107,6 +108,106 @@ Proxy::broadcast(Params&&... params) const { return typename Proxy::PendingSendType{std::nullptr_t{}}; } + +template +template < + auto f, + template class Op, + typename... Args +> +typename Proxy::PendingSendType +Proxy::allreduce( + Args&&... args +) const { + using Tuple = typename FuncTraits::TupleType; + using MsgT = collective::ReduceTMsg; + using GetReduceStamp = collective::reduce::GetReduceStamp; + auto cb = theCB()->makeBcast(*this); + + ReduceStamp stamp{}; + if constexpr (GetReduceStamp::value) { + stamp = std::get(std::tie(std::forward(args)...)); + } + + auto msg = GetReduceStamp::template getMsg(std::forward(args)...); + msg->setCallback(cb); + auto proxy = Proxy(*this); + return theObjGroup()->reduce< + ObjT, + MsgT, + &MsgT::template msgHandler< + MsgT, Op, collective::reduce::operators::ReduceCallback + > + >(proxy, msg.get(), stamp); +} + +template +template < + auto f, + template class Op, + typename Target, + typename... Args +> +typename Proxy::PendingSendType +Proxy::reduce( + Target target, + Args&&... 
args +) const { + using Tuple = typename FuncTraits::TupleType; + using MsgT = collective::ReduceTMsg; + using GetReduceStamp = collective::reduce::GetReduceStamp; + auto cb = theCB()->makeSend(target); + + ReduceStamp stamp{}; + if constexpr (GetReduceStamp::value ) { + stamp = std::get(std::tie(std::forward(args)...)); + } + + auto msg = GetReduceStamp::template getMsg(std::forward(args)...); + msg->setCallback(cb); + auto proxy = Proxy(*this); + return theObjGroup()->reduce< + ObjT, + MsgT, + &MsgT::template msgHandler< + MsgT, Op, collective::reduce::operators::ReduceCallback + > + >(proxy, msg.get(), stamp); +} + +template +template < + template class Op, + typename... CBArgs, + typename... Args +> +typename Proxy::PendingSendType +Proxy::reduce( + vt::Callback cb, + Args&&... args +) const { + using CallbackT = vt::Callback; + using Tuple = typename CallbackT::TupleType; + using MsgT = collective::ReduceTMsg; + using GetReduceStamp = collective::reduce::GetReduceStamp; + + ReduceStamp stamp{}; + if constexpr (GetReduceStamp::value ) { + stamp = std::get(std::tie(std::forward(args)...)); + } + + auto msg = GetReduceStamp::template getMsg(std::forward(args)...); + msg->setCallback(cb); + auto proxy = Proxy(*this); + return theObjGroup()->reduce< + ObjT, + MsgT, + &MsgT::template msgHandler< + MsgT, Op, collective::reduce::operators::ReduceCallback + > + >(proxy, msg.get(), stamp); +} + template template < typename OpT, typename MsgPtrT, typename MsgT, ActiveTypedFnType *f diff --git a/src/vt/phase/phase_manager.cc b/src/vt/phase/phase_manager.cc index 70fceabf29..8ab41ab086 100644 --- a/src/vt/phase/phase_manager.cc +++ b/src/vt/phase/phase_manager.cc @@ -156,8 +156,6 @@ void PhaseManager::startup() { runHooks(PhaseHook::Start); } -struct NextMsg : collective::ReduceNoneMsg {}; - void PhaseManager::nextPhaseCollective() { vtAbortIf( in_next_phase_collective_, @@ -175,9 +173,7 @@ void PhaseManager::nextPhaseCollective() { auto proxy = 
objgroup::proxy::Proxy(proxy_); // Start with a reduction to sure all nodes are ready for this - auto cb = theCB()->makeBcast<&PhaseManager::nextPhaseReduce>(proxy); - auto msg = makeMessage(); - proxy.reduce(msg.get(), cb); + proxy.allreduce<&PhaseManager::nextPhaseReduce>(); theSched()->runSchedulerWhile([this]{ return not reduce_next_phase_done_; }); reduce_next_phase_done_ = false; @@ -203,9 +199,7 @@ void PhaseManager::nextPhaseCollective() { runHooks(PhaseHook::Start); // Start with a reduction to sure all nodes are ready for this - auto cb2 = theCB()->makeBcast<&PhaseManager::nextPhaseDone>(proxy); - auto msg2 = makeMessage(); - proxy.reduce(msg2.get(), cb2); + proxy.allreduce<&PhaseManager::nextPhaseDone>(); theSched()->runSchedulerWhile([this]{ return not reduce_finished_; }); reduce_finished_ = false; @@ -213,11 +207,11 @@ void PhaseManager::nextPhaseCollective() { in_next_phase_collective_ = false; } -void PhaseManager::nextPhaseReduce(NextMsg* msg) { +void PhaseManager::nextPhaseReduce() { reduce_next_phase_done_ = true; } -void PhaseManager::nextPhaseDone(NextMsg* msg) { +void PhaseManager::nextPhaseDone() { reduce_finished_ = true; } diff --git a/src/vt/phase/phase_manager.h b/src/vt/phase/phase_manager.h index 33c9e0a1a7..62f209a100 100644 --- a/src/vt/phase/phase_manager.h +++ b/src/vt/phase/phase_manager.h @@ -57,9 +57,6 @@ namespace vt { namespace phase { -// fwd-decl for reduce message -struct NextMsg; - /** * \struct PhaseManager * @@ -186,7 +183,7 @@ struct PhaseManager : runtime::component::Component { * * \param[in] msg the (empty) next phase message */ - void nextPhaseReduce(NextMsg* msg); + void nextPhaseReduce(); /** * \internal @@ -194,7 +191,7 @@ struct PhaseManager : runtime::component::Component { * * \param[in] msg the (empty) next phase message */ - void nextPhaseDone(NextMsg* msg); + void nextPhaseDone(); /** * \internal diff --git a/src/vt/pipe/callback/cb_union/cb_raw_base.h b/src/vt/pipe/callback/cb_union/cb_raw_base.h index 
00727c9d5d..9eb619865c 100644 --- a/src/vt/pipe/callback/cb_union/cb_raw_base.h +++ b/src/vt/pipe/callback/cb_union/cb_raw_base.h @@ -161,6 +161,8 @@ struct CallbackRawBaseSingle { template struct CallbackTyped : CallbackRawBaseSingle { + using TupleType = std::tuple; + CallbackTyped() = default; CallbackTyped(CallbackTyped const&) = default; CallbackTyped(CallbackTyped&&) = default; @@ -228,6 +230,14 @@ struct CallbackTyped : CallbackRawBaseSingle { return other.pipe_ == pipe_ && other.cb_ == cb_; } + template + void sendTuple(std::tuple tup) { + using Trait = CBTraits; + using MsgT = messaging::ParamMsg; + auto msg = vt::makeMessage(std::move(tup)); + CallbackRawBaseSingle::sendMsg(msg); + } + template void send(Params&&... params) { using Trait = CBTraits; diff --git a/src/vt/pipe/pipe_manager_tl.impl.h b/src/vt/pipe/pipe_manager_tl.impl.h index 34a9c71287..529b728b9c 100644 --- a/src/vt/pipe/pipe_manager_tl.impl.h +++ b/src/vt/pipe/pipe_manager_tl.impl.h @@ -204,7 +204,7 @@ auto PipeManagerTL::makeCallbackProxy(ProxyT proxy) { HandlerType han = uninitialized_handler; if constexpr (std::is_same_v) { using Tuple = typename Trait::TupleType; - using PMsgT = messaging::ParamMsg; + using PMsgT = messaging::ParamMsg; han = auto_registry::makeAutoHandlerObjGroupParam< ObjT, decltype(f), f, PMsgT >(ctrl); diff --git a/src/vt/termination/graph/epoch_graph_reduce.h b/src/vt/termination/graph/epoch_graph_reduce.h index 09a5f363aa..184257eb64 100644 --- a/src/vt/termination/graph/epoch_graph_reduce.h +++ b/src/vt/termination/graph/epoch_graph_reduce.h @@ -49,7 +49,7 @@ namespace vt { namespace collective { namespace reduce { namespace operators { template -struct ReduceTMsg; +struct ReduceDataMsg; }}}} /* end namespace vt::collective::reduce::operators */ @@ -58,14 +58,14 @@ namespace vt { namespace termination { namespace graph { // Must be templated (where T = graph::EpochGraph) because of the circular // dependency between termination.h and reduce.h template -struct 
EpochGraphMsg : collective::reduce::operators::ReduceTMsg { - using MessageParentType = collective::reduce::operators::ReduceTMsg; +struct EpochGraphMsg : collective::reduce::operators::ReduceDataMsg { + using MessageParentType = collective::reduce::operators::ReduceDataMsg; vt_msg_serialize_if_needed_by_parent(); EpochGraphMsg() = default; explicit EpochGraphMsg(std::shared_ptr const& graph) - : collective::reduce::operators::ReduceTMsg(*graph) + : collective::reduce::operators::ReduceDataMsg(*graph) { } template diff --git a/src/vt/utils/fntraits/fntraits.h b/src/vt/utils/fntraits/fntraits.h index 768af84651..7ae429bcbd 100644 --- a/src/vt/utils/fntraits/fntraits.h +++ b/src/vt/utils/fntraits/fntraits.h @@ -54,15 +54,17 @@ struct ObjFuncTraitsImpl; template struct ObjFuncTraitsImpl< std::enable_if_t< - std::is_convertible::value or + (std::is_convertible::value or std::is_convertible::value or std::is_convertible::value or - std::is_convertible::value + std::is_convertible::value) + and + std::is_pointer::value >, - Return(*)(Obj*, Msg*) + Return(*)(Obj, Msg*) > { static constexpr bool is_member = false; - using ObjT = Obj; + using ObjT = std::remove_pointer_t; using MsgT = Msg; using ReturnT = Return; template