Skip to content

Commit

Permalink
#276: tutorial: update and cleanup with new callback/reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
lifflander committed Jun 13, 2023
1 parent 27ba5fa commit 6ae9090
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 89 deletions.
66 changes: 11 additions & 55 deletions tutorial/tutorial_1g.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,45 +46,24 @@
namespace vt { namespace tutorial {

/// [Tutorial1G]
// VT Base Message
// \----------------/
// \ /
struct DataMsg : ::vt::Message { };

struct MsgWithCallback : ::vt::Message {
MsgWithCallback() = default;
explicit MsgWithCallback(Callback<DataMsg> in_cb) : cb(in_cb) {}

Callback<DataMsg> cb;
};


// Forward declaration for the active message handler
static void getCallbackHandler(MsgWithCallback* msg);
static void getCallbackHandler(Callback<int> cb);

// An active message handler used as the target for a callback
static void callbackHandler(DataMsg* msg) {
static void callbackHandler(int data) {
NodeType const cur_node = ::vt::theContext()->getNode();
::fmt::print("{}: triggering active message callback\n", cur_node);
::fmt::print("{}: triggering active message callback: {}\n", cur_node, data);
}

// An active message handler used as the target for a callback
static void callbackBcastHandler(DataMsg* msg) {
NodeType const cur_node = ::vt::theContext()->getNode();
::fmt::print("{}: triggering active message callback bcast\n", cur_node);
}

// A simple context object
struct MyContext { };
static MyContext ctx = {};

// A message handler with context used as the target for a callback
static void callbackCtx(DataMsg* msg, MyContext* cbctx) {
static void callbackBcastHandler(int data) {
NodeType const cur_node = ::vt::theContext()->getNode();
::fmt::print("{}: triggering context callback\n", cur_node);
::fmt::print(
"{}: triggering active message callback bcast: {}\n", cur_node, data
);
}


// Tutorial code to demonstrate using a callback
static inline void activeMessageCallback() {
NodeType const this_node = ::vt::theContext()->getNode();
Expand All @@ -104,50 +83,27 @@ static inline void activeMessageCallback() {
// Node that we want to callback to execute on
NodeType const cb_node = 0;

// Example lambda callback (void)
auto fn = [=](DataMsg* msg){
::fmt::print("{}: triggering function callback\n", this_node);
};

// Example of a void lambda callback
{
auto cb = ::vt::theCB()->makeFunc<DataMsg>(vt::pipe::LifetimeEnum::Once, fn);
auto msg = ::vt::makeMessage<MsgWithCallback>(cb);
::vt::theMsg()->sendMsg<getCallbackHandler>(to_node, msg);
}

// Example of active message handler callback with send node
{
auto cb = ::vt::theCB()->makeSend<callbackHandler>(cb_node);
auto msg = ::vt::makeMessage<MsgWithCallback>(cb);
::vt::theMsg()->sendMsg<getCallbackHandler>(to_node, msg);
::vt::theMsg()->send<getCallbackHandler>(Node{to_node}, cb);
}

// Example of active message handler callback with broadcast
{
auto cb = ::vt::theCB()->makeBcast<callbackBcastHandler>();
auto msg = ::vt::makeMessage<MsgWithCallback>(cb);
::vt::theMsg()->sendMsg<getCallbackHandler>(to_node, msg);
}

// Example of context callback
{
auto cb = ::vt::theCB()->makeFunc<DataMsg,MyContext>(
vt::pipe::LifetimeEnum::Once, &ctx, callbackCtx
);
auto msg = ::vt::makeMessage<MsgWithCallback>(cb);
::vt::theMsg()->sendMsg<getCallbackHandler>(to_node, msg);
::vt::theMsg()->send<getCallbackHandler>(Node{to_node}, cb);
}
}
}

// Message handler for to receive callback and invoke it
static void getCallbackHandler(MsgWithCallback* msg) {
static void getCallbackHandler(Callback<int> cb) {
auto const cur_node = ::vt::theContext()->getNode();
::fmt::print("getCallbackHandler: triggered on node={}\n", cur_node);

// Send the callback a message
msg->cb.send();
cb.send(29 + cur_node);
}
/// [Tutorial1G]

Expand Down
42 changes: 8 additions & 34 deletions tutorial/tutorial_1h.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,25 +46,12 @@
namespace vt { namespace tutorial {

/// [Tutorial1H]
// Reduce Message VT Base Class
// \--------------------------------------------/
// \ /
// \ Reduce Data /
// \ \-----------/
// \ \ /
struct ReduceDataMsg : ::vt::collective::ReduceTMsg<int32_t> {};


// Functor that is the target of the reduction
struct ReduceResult {
void operator()(ReduceDataMsg* msg) {
NodeType const num_nodes = ::vt::theContext()->getNumNodes();
fmt::print("reduction value={}\n", msg->getConstVal());
assert(num_nodes * 50 == msg->getConstVal());
(void)num_nodes; // don't warn about unused value when not debugging
}
};

void reduceResult(int result) {
NodeType const num_nodes = ::vt::theContext()->getNumNodes();
(void)num_nodes; // don't warn about possibly unused variable
fmt::print("reduction value={}\n", result);
assert(num_nodes * 50 == result);
}

// Tutorial code to demonstrate using reduction on all nodes
static inline void activeMessageReduce() {
Expand All @@ -76,23 +63,10 @@ static inline void activeMessageReduce() {
/*
* Perform reduction over all the nodes.
*/

// This is the type of the reduction (uses the plus operator over the data
// type). Once can implement their own data type and overload the plus
// operator for the combine during the reduce
using ReduceOp = ::vt::collective::PlusOp<int32_t>;

NodeType const root_reduce_node = 0;

auto reduce_msg = ::vt::makeMessage<ReduceDataMsg>();

// Get a reference to the value to set it in this reduce msg
reduce_msg->getVal() = 50;

auto const default_proxy = theObjGroup()->getDefault();
default_proxy.reduceMsg<ReduceOp, ReduceResult>(
root_reduce_node, reduce_msg.get()
);
auto r = vt::theCollective()->global();
r->reduce<reduceResult, collective::PlusOp>(vt::Node{root_reduce_node}, 50);
}
/// [Tutorial1H]

Expand Down

0 comments on commit 6ae9090

Please sign in to comment.