Skip to content

Commit

Permalink
#276: pipe: finish implementation across objgroup and collections
Browse files Browse the repository at this point in the history
  • Loading branch information
lifflander committed Jun 13, 2023
1 parent dba9fc1 commit 979c52f
Show file tree
Hide file tree
Showing 45 changed files with 772 additions and 714 deletions.
21 changes: 6 additions & 15 deletions examples/collection/jacobi1d_vt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,7 @@ static constexpr double const default_tol = 1.0e-02;

struct NodeObj {
bool is_finished_ = false;
struct WorkFinishedMsg : vt::Message {};

void workFinishedHandler(WorkFinishedMsg*) { is_finished_ = true; }
void workFinishedHandler() { is_finished_ = true; }
bool isWorkFinished() { return is_finished_; }
};
using NodeObjProxy = vt::objgroup::proxy::Proxy<NodeObj>;
Expand Down Expand Up @@ -119,19 +117,12 @@ struct LinearPb1DJacobi : vt::Collection<LinearPb1DJacobi,vt::Index1D> {

};

struct ReduxMsg : vt::collective::ReduceTMsg<double> {
ReduxMsg() = default;
explicit ReduxMsg(double in_val) : ReduceTMsg<double>(in_val) { }
};

void checkCompleteCB(ReduxMsg* msg) {
void checkCompleteCB(double normRes) {
//
// Only one object for the reduction will visit
// this function
//

double normRes = msg->getConstVal();

auto const iter_max_reached = iter_ > maxIter_;
auto const norm_res_done = normRes < default_tol;

Expand All @@ -143,7 +134,7 @@ struct LinearPb1DJacobi : vt::Collection<LinearPb1DJacobi,vt::Index1D> {
fmt::print(to_print);

// Notify all nodes that computation is finished
objProxy_.broadcast<NodeObj::WorkFinishedMsg, &NodeObj::workFinishedHandler>();
objProxy_.broadcast<&NodeObj::workFinishedHandler>();
} else {
fmt::print(" ## ITER {} >> Residual Norm = {} \n", iter_, normRes);
}
Expand Down Expand Up @@ -184,9 +175,9 @@ struct LinearPb1DJacobi : vt::Collection<LinearPb1DJacobi,vt::Index1D> {
}

auto proxy = this->getCollectionProxy();
auto cb = vt::theCB()->makeSend<&LinearPb1DJacobi::checkCompleteCB>(proxy[0]);
auto msg2 = vt::makeMessage<ReduxMsg>(maxNorm);
proxy.reduce<vt::collective::MaxOp<double>>(msg2.get(),cb);
proxy.reduce<&LinearPb1DJacobi::checkCompleteCB, vt::collective::MaxOp>(
proxy[0], maxNorm
);

}

Expand Down
21 changes: 6 additions & 15 deletions examples/collection/jacobi2d_vt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,7 @@ static constexpr double const default_tol = 1.0e-02;

struct NodeObj {
bool is_finished_ = false;
struct WorkFinishedMsg : vt::Message {};

void workFinishedHandler(WorkFinishedMsg*) { is_finished_ = true; }
void workFinishedHandler() { is_finished_ = true; }
bool isWorkFinished() { return is_finished_; }
};
using NodeObjProxy = vt::objgroup::proxy::Proxy<NodeObj>;
Expand Down Expand Up @@ -129,19 +127,12 @@ struct LinearPb2DJacobi : vt::Collection<LinearPb2DJacobi,vt::Index2D> {
};


struct ReduxMsg : vt::collective::ReduceTMsg<double> {
ReduxMsg() = default;
explicit ReduxMsg(double in_val) : ReduceTMsg<double>(in_val) { }
};


void checkCompleteCB(ReduxMsg* msg) {
void checkCompleteCB(double const normRes) {
//
// Only one object for the reduction will visit
// this function
//

const double normRes = msg->getConstVal();
auto const iter_max_reached = iter_ > maxIter_;
auto const norm_res_done = normRes < default_tol;

Expand All @@ -153,7 +144,7 @@ struct LinearPb2DJacobi : vt::Collection<LinearPb2DJacobi,vt::Index2D> {
fmt::print(to_print);

// Notify all nodes that computation is finished
objProxy_.broadcast<NodeObj::WorkFinishedMsg, &NodeObj::workFinishedHandler>();
objProxy_.broadcast<&NodeObj::workFinishedHandler>();
} else {
fmt::print(" ## ITER {} >> Residual Norm = {} \n", iter_, normRes);
}
Expand Down Expand Up @@ -222,9 +213,9 @@ struct LinearPb2DJacobi : vt::Collection<LinearPb2DJacobi,vt::Index2D> {
}

auto proxy = this->getCollectionProxy();
auto cb = vt::theCB()->makeSend<&LinearPb2DJacobi::checkCompleteCB>(proxy(0,0));
auto msg2 = vt::makeMessage<ReduxMsg>(maxNorm);
proxy.reduce<vt::collective::MaxOp<double>>(msg2.get(),cb);
proxy.reduce<&LinearPb2DJacobi::checkCompleteCB, vt::collective::MaxOp>(
proxy(0,0), maxNorm
);
}

struct VecMsg : vt::CollectionMessage<LinearPb2DJacobi> {
Expand Down
32 changes: 14 additions & 18 deletions examples/collection/reduce_integral.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,21 +101,17 @@ struct Integration1D : vt::Collection<Integration1D, vt::Index1D> {
numPartsPerObject_(default_nparts_object)
{ }

struct CheckIntegral {

void operator()(ReduceMsg* msg) {
fmt::print(" >> The integral over [0, 1] is {}\n", msg->getConstVal());
fmt::print(
" >> The absolute error is {}\n",
std::fabs(msg->getConstVal() - exactIntegral)
);
//
// Set the 'root_reduce_finished' variable to true.
//
root_reduce_finished = true;
}

};
void checkIntegral(double val) {
fmt::print(" >> The integral over [0, 1] is {}\n", val);
fmt::print(
" >> The absolute error is {}\n",
std::fabs(val - exactIntegral)
);
//
// Set the 'root_reduce_finished' variable to true.
//
root_reduce_finished = true;
}

struct InitMsg : vt::CollectionMessage<Integration1D> {

Expand Down Expand Up @@ -177,9 +173,9 @@ struct Integration1D : vt::Collection<Integration1D, vt::Index1D> {
//

auto proxy = this->getCollectionProxy();
auto msgCB = vt::makeMessage<ReduceMsg>(quadsum);
auto cback = vt::theCB()->makeSend<CheckIntegral>(reduce_root_node);
proxy.reduce<vt::collective::PlusOp<double>>(msgCB.get(),cback);
proxy.reduce<&Integration1D::checkIntegral, vt::collective::PlusOp>(
proxy[0], quadsum
);
}

};
Expand Down
8 changes: 2 additions & 6 deletions examples/collection/transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ struct RequestDataMsg : vt::CollectionMessage<ColT> {
vt::NodeType node_;
};

struct InitMsg : vt::collective::ReduceNoneMsg { };

struct DataMsg : vt::Message {
using MessageParentType = vt::Message;
vt_msg_serialize_required(); // by payload_
Expand Down Expand Up @@ -170,7 +168,7 @@ struct Block : vt::Collection<Block, vt::Index1D> {
);
}

void doneInit(InitMsg* msg) {
void doneInit() {
if (getIndex().x() == 0) {
auto proxy = this->getCollectionProxy();
auto proxy_msg = vt::makeMessage<ProxyMsg>(proxy.getProxy());
Expand All @@ -183,9 +181,7 @@ struct Block : vt::Collection<Block, vt::Index1D> {
initialize();
// Wait for all initializations to complete
auto proxy = this->getCollectionProxy();
auto cb = vt::theCB()->makeBcast<&Block::doneInit>(proxy);
auto empty = vt::makeMessage<InitMsg>();
proxy.reduce(empty.get(), cb);
proxy.allreduce<&Block::doneInit>();
}

private:
Expand Down
19 changes: 7 additions & 12 deletions examples/hello_world/hello_world_collection_reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,24 +45,19 @@

/// [Hello world reduce collection]
struct Hello : vt::Collection<Hello, vt::Index1D> {
using ReduceMsg = vt::collective::ReduceTMsg<int>;

void done(ReduceMsg* msg) {
fmt::print("Reduce complete at {} value {}\n", this->getIndex(), msg->getVal());
void done(int val, double val2) {
fmt::print("Reduce complete at {} value {} {}\n", getIndex(), val, val2);
}

void doWork() {
fmt::print("Hello from {}\n", this->getIndex());
fmt::print("Hello from {}\n", getIndex());

// Get the proxy for the collection
auto proxy = this->getCollectionProxy();

// Create a callback for when the reduction finishes
auto cb = vt::theCB()->makeSend<&Hello::done>(proxy(2));
auto proxy = getCollectionProxy();

// Create and send the reduction message holding an int
auto red_msg = vt::makeMessage<ReduceMsg>(this->getIndex().x());
proxy.reduce<vt::collective::PlusOp<int>>(red_msg.get(),cb);
auto val = getIndex().x();
auto val2 = 2.4;
proxy.allreduce<&Hello::done, vt::collective::PlusOp>(val, val2);
}
};

Expand Down
99 changes: 99 additions & 0 deletions src/vt/collective/reduce/get_reduce_stamp.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
//@HEADER
// *****************************************************************************
//
// get_reduce_stamp.h
// DARMA/vt => Virtual Transport
//
// Copyright 2019-2021 National Technology & Engineering Solutions of Sandia, LLC
// (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the U.S.
// Government retains certain rights in this software.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from this
// software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Questions? Contact darma@sandia.gov
//
// *****************************************************************************
//@HEADER
*/

#if !defined INCLUDED_VT_COLLECTIVE_REDUCE_GET_REDUCE_STAMP_H
#define INCLUDED_VT_COLLECTIVE_REDUCE_GET_REDUCE_STAMP_H

#include "vt/config.h"

namespace vt { namespace collective { namespace reduce {

template <typename enable = void, typename... Args>
struct GetReduceStamp : std::false_type {
template <typename MsgT>
static auto getMsg(Args&&... args) {
return vt::makeMessage<MsgT>(std::tuple{std::forward<Args>(args)...});
}
};

template <>
struct GetReduceStamp<
std::enable_if_t<std::is_same_v<void, void>>
> : std::false_type {
template <typename MsgT>
static auto getMsg() {
return vt::makeMessage<MsgT>(std::tuple<>{});
}
};

template <typename... Args>
struct GetReduceStamp<
std::enable_if_t<
std::is_same_v<
std::decay_t<std::tuple_element_t<sizeof...(Args) - 1, std::tuple<Args...>>>,
collective::reduce::ReduceStamp
>
>,
Args...
> : std::true_type {
template <typename... Params, std::size_t... Is>
static constexpr auto getMsgHelper(
std::tuple<Params...> tp, std::index_sequence<Is...>
) {
return std::tuple{std::get<Is>(tp)...};
}

template <typename MsgT>
static auto getMsg(Args&&... args) {
return vt::makeMessage<MsgT>(
getMsgHelper(
std::tie(std::forward<Args>(args)...),
std::make_index_sequence<sizeof...(Args) - 1>{}
)
);
}
};

}}} /* end namespace vt::collective::reduce */

#endif /*INCLUDED_VT_COLLECTIVE_REDUCE_GET_REDUCE_STAMP_H*/
Loading

0 comments on commit 979c52f

Please sign in to comment.