@@ -86,6 +86,50 @@ struct AllreduceRbnMsg
86
86
int32_t step_ = {};
87
87
};
88
88
89
+ template <typename Scalar>
90
+ struct AllreduceRbnRawMsg
91
+ : Message {
92
+ using MessageParentType = vt::Message;
93
+ vt_msg_serialize_required ();
94
+
95
+
96
+ AllreduceRbnRawMsg () = default ;
97
+ AllreduceRbnRawMsg (AllreduceRbnRawMsg const &) = default ;
98
+ AllreduceRbnRawMsg (AllreduceRbnRawMsg&&) = default ;
99
+ ~AllreduceRbnRawMsg () {
100
+ if (owning_) {
101
+ delete[] val_;
102
+ }
103
+ }
104
+
105
+ AllreduceRbnRawMsg (Scalar* in_val, size_t size, int step = 0 )
106
+ : MessageParentType(),
107
+ val_ (in_val),
108
+ size_(size),
109
+ step_(step) { }
110
+
111
+ template <typename SerializeT>
112
+ void serialize (SerializeT& s) {
113
+ MessageParentType::serialize (s);
114
+
115
+ s | size_;
116
+
117
+ if (s.isUnpacking ()) {
118
+ owning_ = true ;
119
+ val_ = new Scalar[size_];
120
+ }
121
+
122
+ checkpoint::dispatch::serializeArray (s, val_, size_);
123
+
124
+ s | step_;
125
+ }
126
+
127
+ Scalar* val_ = {};
128
+ size_t size_ = {};
129
+ int32_t step_ = {};
130
+ bool owning_ = false ;
131
+ };
132
+
89
133
/* *
90
134
* \struct Rabenseifner
91
135
* \brief Class implementing Rabenseifner's allreduce algorithm.
@@ -103,6 +147,7 @@ template <
103
147
>
104
148
struct Rabenseifner {
105
149
using DataType = DataHandler<DataT>;
150
+ using Scalar = typename DataType::Scalar;
106
151
107
152
/* *
108
153
* \brief Constructor for Rabenseifner's allreduce algorithm.
@@ -111,7 +156,7 @@ struct Rabenseifner {
111
156
* \param num_nodes Total number of nodes involved in the allreduce operation.
112
157
* \param args Additional arguments for initializing the data value.
113
158
*/
114
- template <typename ... Args>
159
+ template <typename ...Args>
115
160
Rabenseifner (
116
161
vt::objgroup::proxy::Proxy<ObjT> parentProxy, NodeType num_nodes,
117
162
Args&&... args);
@@ -123,7 +168,7 @@ struct Rabenseifner {
123
168
*
124
169
* \param args Additional arguments for initializing the data value.
125
170
*/
126
- template <typename ... Args>
171
+ template <typename ...Args>
127
172
void initialize (Args&&... args);
128
173
129
174
/* *
@@ -153,7 +198,7 @@ struct Rabenseifner {
153
198
*
154
199
* \param msg Message containing the data from the partner process.
155
200
*/
156
- void adjustForPowerOfTwoRightHalf (AllreduceRbnMsg<DataT >* msg);
201
+ void adjustForPowerOfTwoRightHalf (AllreduceRbnRawMsg<Scalar >* msg);
157
202
158
203
/* *
159
204
* \brief Handler for adjusting the left half of the process group.
@@ -162,7 +207,7 @@ struct Rabenseifner {
162
207
*
163
208
* \param msg Message containing the data from the partner process.
164
209
*/
165
- void adjustForPowerOfTwoLeftHalf (AllreduceRbnMsg<DataT >* msg);
210
+ void adjustForPowerOfTwoLeftHalf (AllreduceRbnRawMsg<Scalar >* msg);
166
211
167
212
/* *
168
213
* \brief Final adjustment step for non-power-of-two process counts.
@@ -171,7 +216,7 @@ struct Rabenseifner {
171
216
*
172
217
* \param msg Message containing the data from the partner process.
173
218
*/
174
- void adjustForPowerOfTwoFinalPart (AllreduceRbnMsg<DataT >* msg);
219
+ void adjustForPowerOfTwoFinalPart (AllreduceRbnRawMsg<Scalar >* msg);
175
220
176
221
/* *
177
222
* \brief Check if all scatter messages have been received.
@@ -215,7 +260,7 @@ struct Rabenseifner {
215
260
*
216
261
* \param msg Message containing the data from the partner process.
217
262
*/
218
- void scatterReduceIterHandler (AllreduceRbnMsg<DataT >* msg);
263
+ void scatterReduceIterHandler (AllreduceRbnRawMsg<Scalar >* msg);
219
264
220
265
/* *
221
266
* \brief Check if all gather messages have been received.
@@ -259,7 +304,7 @@ struct Rabenseifner {
259
304
*
260
305
* \param msg Message containing the data from the partner process.
261
306
*/
262
- void gatherIterHandler (AllreduceRbnMsg<DataT >* msg);
307
+ void gatherIterHandler (AllreduceRbnRawMsg<Scalar >* msg);
263
308
264
309
/* *
265
310
* \brief Perform the final part of the allreduce operation.
@@ -282,12 +327,13 @@ struct Rabenseifner {
282
327
*
283
328
* \param msg Message containing the final result.
284
329
*/
285
- void sendToExcludedNodesHandler (AllreduceRbnMsg<DataT >* msg);
330
+ void sendToExcludedNodesHandler (AllreduceRbnRawMsg<Scalar >* msg);
286
331
287
332
vt::objgroup::proxy::Proxy<Rabenseifner> proxy_ = {};
288
333
vt::objgroup::proxy::Proxy<ObjT> parent_proxy_ = {};
289
334
290
- DataT val_ = {};
335
+ // DataT val_ = {};
336
+ std::vector<Scalar> val_;
291
337
size_t size_ = {};
292
338
NodeType num_nodes_ = {};
293
339
NodeType this_node_ = {};
@@ -297,10 +343,10 @@ struct Rabenseifner {
297
343
int32_t nprocs_pof2_ = {};
298
344
int32_t nprocs_rem_ = {};
299
345
300
- std::vector<int32_t > r_index_ = {};
301
- std::vector<int32_t > r_count_ = {};
302
- std::vector<int32_t > s_index_ = {};
303
- std::vector<int32_t > s_count_ = {};
346
+ std::vector<uint32_t > r_index_ = {};
347
+ std::vector<uint32_t > r_count_ = {};
348
+ std::vector<uint32_t > s_index_ = {};
349
+ std::vector<uint32_t > s_count_ = {};
304
350
305
351
NodeType vrt_node_ = {};
306
352
bool is_part_of_adjustment_group_ = false ;
@@ -314,7 +360,7 @@ struct Rabenseifner {
314
360
int32_t scatter_num_recv_ = 0 ;
315
361
std::vector<bool > scatter_steps_recv_ = {};
316
362
std::vector<bool > scatter_steps_reduced_ = {};
317
- std::vector<MsgSharedPtr<AllreduceRbnMsg<DataT >>> scatter_messages_ = {};
363
+ std::vector<MsgSharedPtr<AllreduceRbnRawMsg<Scalar >>> scatter_messages_ = {};
318
364
bool finished_scatter_part_ = false ;
319
365
320
366
// Gather
@@ -323,7 +369,7 @@ struct Rabenseifner {
323
369
int32_t gather_num_recv_ = 0 ;
324
370
std::vector<bool > gather_steps_recv_ = {};
325
371
std::vector<bool > gather_steps_reduced_ = {};
326
- std::vector<MsgSharedPtr<AllreduceRbnMsg<DataT >>> gather_messages_ = {};
372
+ std::vector<MsgSharedPtr<AllreduceRbnRawMsg<Scalar >>> gather_messages_ = {};
327
373
};
328
374
329
375
} // namespace vt::collective::reduce::allreduce
0 commit comments