Skip to content

Commit

Permalink
Merge pull request #2064 from DARMA-tasking/276-automatically-synthes…
Browse files Browse the repository at this point in the history
…ize-message-types-for-sends-to-arbitrary-handlers

276 Automatically synthesize and deduce message types for sends to arbitrary handlers
  • Loading branch information
lifflander authored Mar 15, 2023
2 parents 23ac286 + 6b69c91 commit edf105d
Show file tree
Hide file tree
Showing 140 changed files with 1,489 additions and 1,081 deletions.
8 changes: 0 additions & 8 deletions docs/md/param.md

This file was deleted.

1 change: 0 additions & 1 deletion docs/md/vt.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ management.
| \subpage location | `vt::theLocMan()` | \copybrief location | @m_class{m-label m-success} **Core** |
| \subpage mem-usage | `vt::theMemUsage()` | \copybrief mem-usage | @m_class{m-label m-warning} **Optional** |
| \subpage objgroup | `vt::theObjGroup()` | \copybrief objgroup | @m_class{m-label m-success} **Core** |
| \subpage param | `vt::theParam()` | \copybrief param | @m_class{m-label m-danger} **Experimental** |
| \subpage pipe | `vt::theCB()` | \copybrief pipe | @m_class{m-label m-success} **Core** |
| \subpage node-lb-data | `vt::theNodeLBData()` | \copybrief node-lb-data | @m_class{m-label m-warning} **Optional** |
| \subpage phase | `vt::thePhase()` | \copybrief phase | @m_class{m-label m-success} **Core** |
Expand Down
4 changes: 2 additions & 2 deletions examples/callback/callback.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,13 @@ struct MyObj {
struct MyCol : vt::Collection<MyCol, vt::Index1D> { };

// Collection handler callback endpoint
void colHan(TestMsg* msg, MyCol* col) {
void colHan(MyCol* col, TestMsg* msg) {
printOutput(msg, "MyCol colHan (non-intrusive)");
}

void bounceCallback(vt::Callback<TestMsg> cb) {
auto msg = vt::makeMessage<HelloMsg>(cb);
vt::theMsg()->sendMsg<HelloMsg, hello_world>(1, msg);
vt::theMsg()->sendMsg<hello_world>(1, msg);
}

int main(int argc, char** argv) {
Expand Down
29 changes: 7 additions & 22 deletions examples/collection/lb_iter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,22 +47,7 @@ static constexpr int32_t const default_num_elms = 64;
static int32_t num_iter = 8;

struct IterCol : vt::Collection<IterCol, vt::Index1D> {
IterCol() = default;

struct IterMsg : vt::CollectionMessage<IterCol> {
IterMsg() = default;
IterMsg(
int64_t const in_work_amt, int64_t const in_iter, int64_t const subphase
)
: iter_(in_iter), work_amt_(in_work_amt), subphase_(subphase)
{ }

int64_t iter_ = 0;
int64_t work_amt_ = 0;
int64_t subphase_ = 0;
};

void iterWork(IterMsg* msg);
void iterWork(int64_t work_amt, int64_t iter, int subphase);

template <typename SerializerT>
void serialize(SerializerT& s) {
Expand All @@ -76,10 +61,10 @@ struct IterCol : vt::Collection<IterCol, vt::Index1D> {

static double weight = 1.0f;

void IterCol::iterWork(IterMsg* msg) {
this->lb_data_.setSubPhase(msg->subphase_);
void IterCol::iterWork(int64_t work_amt, int64_t iter, int subphase) {
this->lb_data_.setSubPhase(subphase);
double val = 0.1f;
double val2 = 0.4f * msg->work_amt_;
double val2 = 0.4f * work_amt;
auto const idx = getIndex().x();
int64_t const max_work = 1000 * weight;
int64_t const mid_work = 100 * weight;
Expand Down Expand Up @@ -131,13 +116,13 @@ int main(int argc, char** argv) {
auto cur_time = vt::timing::getCurrentTime();

vt::runInEpochCollective([=]{
proxy.broadcastCollective<IterCol::IterMsg,&IterCol::iterWork>(10, i, 0);
proxy.broadcastCollective<&IterCol::iterWork>(10, i, 0);
});
vt::runInEpochCollective([=]{
proxy.broadcastCollective<IterCol::IterMsg,&IterCol::iterWork>(5, i, 1);
proxy.broadcastCollective<&IterCol::iterWork>(5, i, 1);
});
vt::runInEpochCollective([=]{
proxy.broadcastCollective<IterCol::IterMsg,&IterCol::iterWork>(15, i, 2);
proxy.broadcastCollective<&IterCol::iterWork>(15, i, 2);
});

auto total_time = vt::timing::getCurrentTime() - cur_time;
Expand Down
22 changes: 5 additions & 17 deletions examples/collection/migrate_collection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,20 +65,12 @@ struct Hello : vt::Collection<Hello, vt::Index1D> {
double test_val = 0.0;
};

struct ColMsg : vt::CollectionMessage<Hello> {
explicit ColMsg(vt::NodeType const& in_from_node)
: from_node(in_from_node)
{ }

vt::NodeType from_node = vt::uninitialized_destination;
};

static void doWork(ColMsg* msg, Hello* col) {
static void doWork(Hello* col) {
vt::NodeType this_node = vt::theContext()->getNode();
fmt::print("{}: idx={}: val={}\n", this_node, col->getIndex(), col->test_val);
}

static void migrateToNext(ColMsg* msg, Hello* col) {
static void migrateToNext(Hello* col) {
vt::NodeType this_node = vt::theContext()->getNode();
vt::NodeType num_nodes = vt::theContext()->getNumNodes();
vt::NodeType next_node = (this_node + 1) % num_nodes;
Expand Down Expand Up @@ -109,13 +101,9 @@ int main(int argc, char** argv) {
.wait();

if (this_node == 0) {
vt::runInEpochRooted([=] { proxy.broadcast<ColMsg, doWork>(this_node); });

vt::runInEpochRooted(
[=] { proxy.broadcast<ColMsg, migrateToNext>(this_node); }
);

vt::runInEpochRooted([=] { proxy.broadcast<ColMsg, doWork>(this_node); });
vt::runInEpochRooted([=] { proxy.broadcast<doWork>(); });
vt::runInEpochRooted([=] { proxy.broadcast<migrateToNext>(); });
vt::runInEpochRooted([=] { proxy.broadcast<doWork>(); });
}

vt::finalize();
Expand Down
41 changes: 12 additions & 29 deletions examples/collection/polymorphic_collection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
/// [Polymorphic collection example]
static constexpr int32_t const default_num_elms = 16;
struct InitialConsTag{};
struct ColMsg;

struct Hello : vt::Collection<Hello, vt::Index1D> {
checkpoint_virtual_serialize_root()
Expand All @@ -62,19 +61,11 @@ struct Hello : vt::Collection<Hello, vt::Index1D> {
s | test_val;
}

virtual void doWork(ColMsg* msg);
virtual void doWork();

double test_val = 0.0;
};

struct ColMsg : vt::CollectionMessage<Hello> {
explicit ColMsg(vt::NodeType const& in_from_node)
: from_node(in_from_node)
{ }

vt::NodeType from_node = vt::uninitialized_destination;
};

template <typename T>
struct HelloTyped : Hello {
checkpoint_virtual_serialize_derived_from(Hello)
Expand All @@ -84,7 +75,7 @@ struct HelloTyped : Hello {
: Hello(checkpoint::SERIALIZE_CONSTRUCT_TAG{})
{}

virtual void doWork(ColMsg* msg) override;
virtual void doWork() override;

template <typename Serializer>
void serialize(Serializer& s) {
Expand All @@ -95,14 +86,14 @@ struct HelloTyped : Hello {
};

template <>
void HelloTyped<int>::doWork(ColMsg* msg) {
void HelloTyped<int>::doWork() {
fmt::print("correctly doing this -- int!\n");
Hello::doWork(msg);
Hello::doWork();
}

template <>
void HelloTyped<double>::doWork(ColMsg* msg) {
Hello::doWork(msg);
void HelloTyped<double>::doWork() {
Hello::doWork();
fmt::print("correctly doing this -- double!\n");
}

Expand All @@ -124,7 +115,7 @@ HelloTyped<double>::HelloTyped(InitialConsTag)
}
}

void Hello::doWork(ColMsg* msg) {
void Hello::doWork() {
vt_print(
gen, "idx={}: val={}, type={}\n",
getIndex(), test_val, typeid(*this).name()
Expand All @@ -147,7 +138,7 @@ void Hello::doWork(ColMsg* msg) {
}


static void migrateToNext(ColMsg* msg, Hello* col) {
static void migrateToNext(Hello* col) {
vt::NodeType this_node = vt::theContext()->getNode();
vt::NodeType num_nodes = vt::theContext()->getNumNodes();
vt::NodeType next_node = (this_node + 1) % num_nodes;
Expand Down Expand Up @@ -191,21 +182,13 @@ int main(int argc, char** argv) {
.wait();

for (int p = 0; p < 10; p++) {
vt::runInEpochCollective([&]{
proxy.broadcastCollective<ColMsg, &Hello::doWork>(this_node);
});
vt::runInEpochCollective([&]{
proxy.broadcastCollective<ColMsg, migrateToNext>(this_node);
});
vt::runInEpochCollective([&]{
proxy.broadcastCollective<ColMsg, &Hello::doWork>(this_node);
});
vt::runInEpochCollective([&]{ proxy.broadcastCollective<&Hello::doWork>(); });
vt::runInEpochCollective([&]{ proxy.broadcastCollective<migrateToNext>(); });
vt::runInEpochCollective([&]{ proxy.broadcastCollective<&Hello::doWork>(); });
}

for (int p = 0; p < 10; p++) {
vt::runInEpochCollective([&]{
proxy.broadcastCollective<ColMsg, &Hello::doWork>(this_node);
});
vt::runInEpochCollective([&]{ proxy.broadcastCollective<&Hello::doWork>(); });
vt::thePhase()->nextPhaseCollective();
}

Expand Down
4 changes: 2 additions & 2 deletions examples/collection/transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ struct Block : vt::Collection<Block, vt::Index1D> {
);
auto const from_idx = getIndex().x();
auto data_msg = vt::makeMessage<DataMsg>(data_,from_idx);
vt::theMsg()->sendMsg<DataMsg,SubSolveInfo::solveDataIncoming>(
vt::theMsg()->sendMsg<SubSolveInfo::solveDataIncoming>(
requesting_node, data_msg
);
}
Expand Down Expand Up @@ -313,7 +313,7 @@ static void solveGroupSetup(vt::NodeType this_node, vt::VirtualProxyType coll_pr
if (this_node == 1) {
auto msg = vt::makeMessage<SubSolveMsg>(coll_proxy);
vt::envelopeSetGroup(msg->env, group_id);
vt::theMsg()->broadcastMsg<SubSolveMsg,SubSolveInfo::subSolveHandler>(msg);
vt::theMsg()->broadcastMsg<SubSolveInfo::subSolveHandler>(msg);
}
}, true
);
Expand Down
2 changes: 1 addition & 1 deletion examples/group/group_collective.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ int main(int argc, char** argv) {
if (this_node == 1) {
auto msg = vt::makeMessage<HelloGroupMsg>();
vt::envelopeSetGroup(msg->env, group);
vt::theMsg()->broadcastMsg<HelloGroupMsg, hello_group_handler>(msg);
vt::theMsg()->broadcastMsg<hello_group_handler>(msg);
}
}
);
Expand Down
4 changes: 2 additions & 2 deletions examples/group/group_rooted.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,15 @@ int main(int argc, char** argv) {

if (this_node == 0) {
auto msg = vt::makeMessage<HelloMsg>(this_node);
vt::theMsg()->broadcastMsg<HelloMsg, hello_world>(msg);
vt::theMsg()->broadcastMsg<hello_world>(msg);

using RangeType = vt::group::region::Range;
auto list = std::make_unique<RangeType>(num_nodes / 2, num_nodes);

vt::theGroup()->newGroup(std::move(list), [=](vt::GroupType group){
auto gmsg = vt::makeMessage<HelloMsg>(this_node);
vt::envelopeSetGroup(gmsg->env, group);
vt::theMsg()->broadcastMsg<HelloMsg, hello_group_handler>(gmsg);
vt::theMsg()->broadcastMsg<hello_group_handler>(gmsg);
});
}

Expand Down
1 change: 0 additions & 1 deletion examples/hello_world/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ set(
hello_world_virtual_context
hello_world_virtual_context_remote
ring
param
objgroup
)

Expand Down
7 changes: 3 additions & 4 deletions examples/hello_world/hello_world.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ struct HelloMsg : vt::Message {
vt::NodeType from = 0;
};

static void hello_world(HelloMsg* msg) {
void hello_world(int a, int b, float c) {
vt::NodeType this_node = vt::theContext()->getNode();
fmt::print("{}: Hello from node {}\n", this_node, msg->from);
fmt::print("{}: Hello from node vals = {} {} {}\n", this_node, a, b, c);
}

int main(int argc, char** argv) {
Expand All @@ -65,8 +65,7 @@ int main(int argc, char** argv) {
}

if (this_node == 0) {
auto msg = vt::makeMessage<HelloMsg>(this_node);
vt::theMsg()->broadcastMsg<HelloMsg, hello_world>(msg);
vt::theMsg()->send<hello_world>(vt::Node{1}, 10, 20, 11.3f);
}

vt::finalize();
Expand Down
8 changes: 3 additions & 5 deletions examples/hello_world/hello_world_collection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,8 @@ struct Hello : vt::Collection<Hello, vt::Index1D> {
vtAssert(counter_ == 1, "Must be equal");
}

using TestMsg = vt::CollectionMessage<Hello>;

void doWork(TestMsg* msg) {
fmt::print("Hello from {}\n", this->getIndex());
void doWork(int val) {
fmt::print("Hello from {}: val={}\n", this->getIndex(), val);
counter_++;
}

Expand All @@ -79,7 +77,7 @@ int main(int argc, char** argv) {
.bounds(range)
.bulkInsert()
.wait();
proxy.broadcast<Hello::TestMsg,&Hello::doWork>();
proxy.broadcast<&Hello::doWork>(10);
}

vt::finalize();
Expand Down
10 changes: 5 additions & 5 deletions examples/hello_world/hello_world_collection_collective.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ struct Hello : vt::Collection<Hello, vt::Index1D> {
vtAssert(counter_ == num_nodes, "Should receive # nodes broadcasts");
}

using TestMsg = vt::CollectionMessage<Hello>;

void doWork(TestMsg* msg) {
void doWork(int val) {
counter_++;
fmt::print("Hello from {}, counter_={}\n", this->getIndex().x(), counter_);
fmt::print(
"Hello from {}, val={}, counter_={}\n", getIndex(), val, counter_
);
}

private:
Expand All @@ -78,7 +78,7 @@ int main(int argc, char** argv) {
.wait();

// All nodes send a broadcast to all elements
proxy.broadcast<Hello::TestMsg,&Hello::doWork>();
proxy.broadcast<&Hello::doWork>(29);

vt::finalize();

Expand Down
6 changes: 2 additions & 4 deletions examples/hello_world/hello_world_collection_reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ struct Hello : vt::Collection<Hello, vt::Index1D> {
fmt::print("Reduce complete at {} value {}\n", this->getIndex(), msg->getVal());
}

using TestMsg = vt::CollectionMessage<Hello>;

void doWork(TestMsg* msg) {
void doWork() {
fmt::print("Hello from {}\n", this->getIndex());

// Get the proxy for the collection
Expand Down Expand Up @@ -85,7 +83,7 @@ int main(int argc, char** argv) {
.wait();

if (this_node == 0) {
proxy.broadcast<Hello::TestMsg,&Hello::doWork>();
proxy.broadcast<&Hello::doWork>();
}

vt::finalize();
Expand Down
6 changes: 2 additions & 4 deletions examples/hello_world/hello_world_collection_staged_insert.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,7 @@ struct Hello : vt::Collection<Hello, vt::Index1D> {
vtAssert(counter_ == 1, "Must be equal");
}

using TestMsg = vt::CollectionMessage<Hello>;

void doWork(TestMsg* msg) {
void doWork() {
counter_++;

vt::NodeType this_node = vt::theContext()->getNode();
Expand Down Expand Up @@ -107,7 +105,7 @@ int main(int argc, char** argv) {
.wait();

if (this_node == 1) {
proxy.broadcast<Hello::TestMsg,&Hello::doWork>();
proxy.broadcast<&Hello::doWork>();
}

vt::finalize();
Expand Down
Loading

0 comments on commit edf105d

Please sign in to comment.