Skip to content

Commit 0db918f

Browse files
authored
Merge pull request #2136 from DARMA-tasking/276-automatically-synthesize-message-types-for-sends-to-arbitrary-handlers-cb
276 Callback parameterization
2 parents 6109deb + 1012272 commit 0db918f

File tree

95 files changed

+2077
-1619
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

95 files changed

+2077
-1619
lines changed

examples/callback/callback.cc

+5-5
Original file line numberDiff line numberDiff line change
@@ -141,19 +141,19 @@ int main(int argc, char** argv) {
141141
auto cb_functor = vt::theCB()->makeSend<CallbackFunctor>(dest);
142142
bounceCallback(cb_functor);
143143

144-
auto cb_func = vt::theCB()->makeSend<TestMsg,callbackFunc>(dest);
144+
auto cb_func = vt::theCB()->makeSend<callbackFunc>(dest);
145145
bounceCallback(cb_func);
146146

147-
auto cb_obj = vt::theCB()->makeSend<MyObj,TestMsg,&MyObj::handler>(obj[dest]);
147+
auto cb_obj = vt::theCB()->makeSend<&MyObj::handler>(obj[dest]);
148148
bounceCallback(cb_obj);
149149

150-
auto cb_obj_bcast = vt::theCB()->makeBcast<MyObj,TestMsg,&MyObj::handler>(obj);
150+
auto cb_obj_bcast = vt::theCB()->makeBcast<&MyObj::handler>(obj);
151151
bounceCallback(cb_obj_bcast);
152152

153-
auto cb_col = vt::theCB()->makeSend<MyCol,TestMsg,colHan>(col[5]);
153+
auto cb_col = vt::theCB()->makeSend<colHan>(col[5]);
154154
bounceCallback(cb_col);
155155

156-
auto cb_col_bcast = vt::theCB()->makeBcast<MyCol,TestMsg,colHan>(col);
156+
auto cb_col_bcast = vt::theCB()->makeBcast<colHan>(col);
157157
bounceCallback(cb_col_bcast);
158158
}
159159

examples/collection/jacobi1d_vt.cc

+6-17
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,7 @@ static constexpr double const default_tol = 1.0e-02;
7373

7474
struct NodeObj {
7575
bool is_finished_ = false;
76-
struct WorkFinishedMsg : vt::Message {};
77-
78-
void workFinishedHandler(WorkFinishedMsg*) { is_finished_ = true; }
76+
void workFinishedHandler() { is_finished_ = true; }
7977
bool isWorkFinished() { return is_finished_; }
8078
};
8179
using NodeObjProxy = vt::objgroup::proxy::Proxy<NodeObj>;
@@ -119,19 +117,12 @@ struct LinearPb1DJacobi : vt::Collection<LinearPb1DJacobi,vt::Index1D> {
119117

120118
};
121119

122-
struct ReduxMsg : vt::collective::ReduceTMsg<double> {
123-
ReduxMsg() = default;
124-
explicit ReduxMsg(double in_val) : ReduceTMsg<double>(in_val) { }
125-
};
126-
127-
void checkCompleteCB(ReduxMsg* msg) {
120+
void checkCompleteCB(double normRes) {
128121
//
129122
// Only one object for the reduction will visit
130123
// this function
131124
//
132125

133-
double normRes = msg->getConstVal();
134-
135126
auto const iter_max_reached = iter_ > maxIter_;
136127
auto const norm_res_done = normRes < default_tol;
137128

@@ -143,7 +134,7 @@ struct LinearPb1DJacobi : vt::Collection<LinearPb1DJacobi,vt::Index1D> {
143134
fmt::print(to_print);
144135

145136
// Notify all nodes that computation is finished
146-
objProxy_.broadcast<NodeObj::WorkFinishedMsg, &NodeObj::workFinishedHandler>();
137+
objProxy_.broadcast<&NodeObj::workFinishedHandler>();
147138
} else {
148139
fmt::print(" ## ITER {} >> Residual Norm = {} \n", iter_, normRes);
149140
}
@@ -184,11 +175,9 @@ struct LinearPb1DJacobi : vt::Collection<LinearPb1DJacobi,vt::Index1D> {
184175
}
185176

186177
auto proxy = this->getCollectionProxy();
187-
auto cb = vt::theCB()->makeSend<
188-
LinearPb1DJacobi,ReduxMsg,&LinearPb1DJacobi::checkCompleteCB
189-
>(proxy[0]);
190-
auto msg2 = vt::makeMessage<ReduxMsg>(maxNorm);
191-
proxy.reduce<vt::collective::MaxOp<double>>(msg2.get(),cb);
178+
proxy.reduce<&LinearPb1DJacobi::checkCompleteCB, vt::collective::MaxOp>(
179+
proxy[0], maxNorm
180+
);
192181

193182
}
194183

examples/collection/jacobi2d_vt.cc

+6-17
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,7 @@ static constexpr double const default_tol = 1.0e-02;
8080

8181
struct NodeObj {
8282
bool is_finished_ = false;
83-
struct WorkFinishedMsg : vt::Message {};
84-
85-
void workFinishedHandler(WorkFinishedMsg*) { is_finished_ = true; }
83+
void workFinishedHandler() { is_finished_ = true; }
8684
bool isWorkFinished() { return is_finished_; }
8785
};
8886
using NodeObjProxy = vt::objgroup::proxy::Proxy<NodeObj>;
@@ -129,19 +127,12 @@ struct LinearPb2DJacobi : vt::Collection<LinearPb2DJacobi,vt::Index2D> {
129127
};
130128

131129

132-
struct ReduxMsg : vt::collective::ReduceTMsg<double> {
133-
ReduxMsg() = default;
134-
explicit ReduxMsg(double in_val) : ReduceTMsg<double>(in_val) { }
135-
};
136-
137-
138-
void checkCompleteCB(ReduxMsg* msg) {
130+
void checkCompleteCB(double const normRes) {
139131
//
140132
// Only one object for the reduction will visit
141133
// this function
142134
//
143135

144-
const double normRes = msg->getConstVal();
145136
auto const iter_max_reached = iter_ > maxIter_;
146137
auto const norm_res_done = normRes < default_tol;
147138

@@ -153,7 +144,7 @@ struct LinearPb2DJacobi : vt::Collection<LinearPb2DJacobi,vt::Index2D> {
153144
fmt::print(to_print);
154145

155146
// Notify all nodes that computation is finished
156-
objProxy_.broadcast<NodeObj::WorkFinishedMsg, &NodeObj::workFinishedHandler>();
147+
objProxy_.broadcast<&NodeObj::workFinishedHandler>();
157148
} else {
158149
fmt::print(" ## ITER {} >> Residual Norm = {} \n", iter_, normRes);
159150
}
@@ -222,11 +213,9 @@ struct LinearPb2DJacobi : vt::Collection<LinearPb2DJacobi,vt::Index2D> {
222213
}
223214

224215
auto proxy = this->getCollectionProxy();
225-
auto cb = vt::theCB()->makeSend<
226-
LinearPb2DJacobi,ReduxMsg,&LinearPb2DJacobi::checkCompleteCB
227-
>(proxy(0,0));
228-
auto msg2 = vt::makeMessage<ReduxMsg>(maxNorm);
229-
proxy.reduce<vt::collective::MaxOp<double>>(msg2.get(),cb);
216+
proxy.reduce<&LinearPb2DJacobi::checkCompleteCB, vt::collective::MaxOp>(
217+
proxy(0,0), maxNorm
218+
);
230219
}
231220

232221
struct VecMsg : vt::CollectionMessage<LinearPb2DJacobi> {

examples/collection/reduce_integral.cc

+14-18
Original file line numberDiff line numberDiff line change
@@ -101,21 +101,17 @@ struct Integration1D : vt::Collection<Integration1D, vt::Index1D> {
101101
numPartsPerObject_(default_nparts_object)
102102
{ }
103103

104-
struct CheckIntegral {
105-
106-
void operator()(ReduceMsg* msg) {
107-
fmt::print(" >> The integral over [0, 1] is {}\n", msg->getConstVal());
108-
fmt::print(
109-
" >> The absolute error is {}\n",
110-
std::fabs(msg->getConstVal() - exactIntegral)
111-
);
112-
//
113-
// Set the 'root_reduce_finished' variable to true.
114-
//
115-
root_reduce_finished = true;
116-
}
117-
118-
};
104+
void checkIntegral(double val) {
105+
fmt::print(" >> The integral over [0, 1] is {}\n", val);
106+
fmt::print(
107+
" >> The absolute error is {}\n",
108+
std::fabs(val - exactIntegral)
109+
);
110+
//
111+
// Set the 'root_reduce_finished' variable to true.
112+
//
113+
root_reduce_finished = true;
114+
}
119115

120116
struct InitMsg : vt::CollectionMessage<Integration1D> {
121117

@@ -177,9 +173,9 @@ struct Integration1D : vt::Collection<Integration1D, vt::Index1D> {
177173
//
178174

179175
auto proxy = this->getCollectionProxy();
180-
auto msgCB = vt::makeMessage<ReduceMsg>(quadsum);
181-
auto cback = vt::theCB()->makeSend<CheckIntegral>(reduce_root_node);
182-
proxy.reduce<vt::collective::PlusOp<double>>(msgCB.get(),cback);
176+
proxy.reduce<&Integration1D::checkIntegral, vt::collective::PlusOp>(
177+
proxy[0], quadsum
178+
);
183179
}
184180

185181
};

examples/collection/transpose.cc

+2-6
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,6 @@ struct RequestDataMsg : vt::CollectionMessage<ColT> {
5858
vt::NodeType node_;
5959
};
6060

61-
struct InitMsg : vt::collective::ReduceNoneMsg { };
62-
6361
struct DataMsg : vt::Message {
6462
using MessageParentType = vt::Message;
6563
vt_msg_serialize_required(); // by payload_
@@ -170,7 +168,7 @@ struct Block : vt::Collection<Block, vt::Index1D> {
170168
);
171169
}
172170

173-
void doneInit(InitMsg* msg) {
171+
void doneInit() {
174172
if (getIndex().x() == 0) {
175173
auto proxy = this->getCollectionProxy();
176174
auto proxy_msg = vt::makeMessage<ProxyMsg>(proxy.getProxy());
@@ -183,9 +181,7 @@ struct Block : vt::Collection<Block, vt::Index1D> {
183181
initialize();
184182
// Wait for all initializations to complete
185183
auto proxy = this->getCollectionProxy();
186-
auto cb = vt::theCB()->makeBcast<Block, InitMsg, &Block::doneInit>(proxy);
187-
auto empty = vt::makeMessage<InitMsg>();
188-
proxy.reduce(empty.get(), cb);
184+
proxy.allreduce<&Block::doneInit>();
189185
}
190186

191187
private:

examples/hello_world/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ set(
1111
hello_world_virtual_context_remote
1212
ring
1313
objgroup
14+
hello_reduce
1415
)
1516

1617
foreach(EXAMPLE_NAME ${HELLO_WORLD_EXAMPLES})

examples/hello_world/hello_reduce.cc

+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
//@HEADER
3+
// *****************************************************************************
4+
//
5+
// hello_reduce.cc
6+
// DARMA/vt => Virtual Transport
7+
//
8+
// Copyright 2019-2021 National Technology & Engineering Solutions of Sandia, LLC
9+
// (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the U.S.
10+
// Government retains certain rights in this software.
11+
//
12+
// Redistribution and use in source and binary forms, with or without
13+
// modification, are permitted provided that the following conditions are met:
14+
//
15+
// * Redistributions of source code must retain the above copyright notice,
16+
// this list of conditions and the following disclaimer.
17+
//
18+
// * Redistributions in binary form must reproduce the above copyright notice,
19+
// this list of conditions and the following disclaimer in the documentation
20+
// and/or other materials provided with the distribution.
21+
//
22+
// * Neither the name of the copyright holder nor the names of its
23+
// contributors may be used to endorse or promote products derived from this
24+
// software without specific prior written permission.
25+
//
26+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
27+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
28+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
29+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
30+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
31+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
32+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
33+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
34+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
35+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
36+
// POSSIBILITY OF SUCH DAMAGE.
37+
//
38+
// Questions? Contact darma@sandia.gov
39+
//
40+
// *****************************************************************************
41+
//@HEADER
42+
*/
43+
44+
#include <vt/transport.h>
45+
46+
void reduceResult(int result, double result2) {
47+
auto num_nodes = vt::theContext()->getNumNodes();
48+
fmt::print("reduction value={}, {}\n", result, result2);
49+
vtAssert(num_nodes * 50 == result, "Must be equal");
50+
}
51+
52+
int main(int argc, char** argv) {
53+
vt::initialize(argc, argv);
54+
55+
vt::NodeType const root = 0;
56+
57+
auto r = vt::theCollective()->global();
58+
r->reduce<reduceResult, vt::collective::PlusOp>(vt::Node{root}, 50, 52.334);
59+
60+
vt::finalize();
61+
return 0;
62+
}

examples/hello_world/hello_world_collection_reduce.cc

+7-12
Original file line numberDiff line numberDiff line change
@@ -45,24 +45,19 @@
4545

4646
/// [Hello world reduce collection]
4747
struct Hello : vt::Collection<Hello, vt::Index1D> {
48-
using ReduceMsg = vt::collective::ReduceTMsg<int>;
49-
50-
void done(ReduceMsg* msg) {
51-
fmt::print("Reduce complete at {} value {}\n", this->getIndex(), msg->getVal());
48+
void done(int val, double val2) {
49+
fmt::print("Reduce complete at {} values {} {}\n", getIndex(), val, val2);
5250
}
5351

5452
void doWork() {
55-
fmt::print("Hello from {}\n", this->getIndex());
53+
fmt::print("Hello from {}\n", getIndex());
5654

5755
// Get the proxy for the collection
58-
auto proxy = this->getCollectionProxy();
59-
60-
// Create a callback for when the reduction finishes
61-
auto cb = vt::theCB()->makeSend<Hello,ReduceMsg,&Hello::done>(proxy(2));
56+
auto proxy = getCollectionProxy();
6257

63-
// Create and send the reduction message holding an int
64-
auto red_msg = vt::makeMessage<ReduceMsg>(this->getIndex().x());
65-
proxy.reduce<vt::collective::PlusOp<int>>(red_msg.get(),cb);
58+
auto val = getIndex().x();
59+
auto val2 = 2.4;
60+
proxy.allreduce<&Hello::done, vt::collective::PlusOp>(val, val2);
6661
}
6762
};
6863

src/vt/collective/collective_scope.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ TagType CollectiveScope::mpiCollectiveAsync(ActionType action) {
8383
NodeType collective_root = 0;
8484

8585
using CollectiveMsg = CollectiveAlg::CollectiveMsg;
86-
auto cb = theCB()->makeBcast<CollectiveMsg,&CollectiveAlg::runCollective>();
86+
auto cb = theCB()->makeBcast<&CollectiveAlg::runCollective>();
8787
auto msg = makeMessage<CollectiveMsg>(is_user_tag_, scope_, tag, collective_root);
8888

8989
// The tag for the reduce is a combination of the scope and seq tag.

0 commit comments

Comments
 (0)