58
58
namespace vt ::collective::reduce::allreduce {
59
59
60
60
template <typename DataT>
61
- struct AllreduceDblMsg
62
- : SerializeIfNeeded<vt::Message, AllreduceDblMsg<DataT>, DataT> {
63
- using MessageParentType =
64
- SerializeIfNeeded<::vt::Message, AllreduceDblMsg<DataT>, DataT>;
65
-
66
- AllreduceDblMsg () = default ;
67
- AllreduceDblMsg (AllreduceDblMsg const &) = default ;
68
- AllreduceDblMsg (AllreduceDblMsg&&) = default ;
69
-
70
- AllreduceDblMsg (DataT&& in_val, int step = 0 )
71
- : MessageParentType(),
72
- val_ (std::forward<DataT>(in_val)),
73
- step_(step) { }
74
- AllreduceDblMsg (DataT const & in_val, int step = 0 )
75
- : MessageParentType(),
76
- val_(in_val),
77
- step_(step) { }
78
-
79
- template <typename SerializeT>
80
- void serialize (SerializeT& s) {
81
- MessageParentType::serialize (s);
82
- s | val_;
83
- s | step_;
84
- }
85
-
86
- DataT val_ = {};
87
- int32_t step_ = {};
88
- };
89
-
90
- template <typename Scalar>
91
61
struct AllreduceDblRawMsg
92
62
: Message {
93
63
using MessageParentType = vt::Message;
@@ -99,34 +69,32 @@ struct AllreduceDblRawMsg
99
69
AllreduceDblRawMsg (AllreduceDblRawMsg&&) = default ;
100
70
~AllreduceDblRawMsg () {
101
71
if (owning_) {
102
- delete[] val_;
72
+ delete val_;
103
73
}
104
74
}
105
75
106
- AllreduceDblRawMsg (std::vector<Scalar> & in_val, int step = 0 )
76
+ AllreduceDblRawMsg (DataT const & in_val, size_t id , int step = 0 )
107
77
: MessageParentType(),
108
- val_ (in_val.data() ),
109
- size_(in_val.size() ),
78
+ val_ (& in_val),
79
+ id_(id ),
110
80
step_(step) { }
111
81
112
82
template <typename SerializeT>
113
83
void serialize (SerializeT& s) {
114
84
MessageParentType::serialize (s);
115
85
116
- s | size_;
117
-
118
86
if (s.isUnpacking ()) {
119
87
owning_ = true ;
120
- val_ = new Scalar[size_] ;
88
+ val_ = new DataT () ;
121
89
}
122
90
123
- checkpoint::dispatch::serializeArray (s, val_, size_) ;
124
-
91
+ s | *val_ ;
92
+ s | id_;
125
93
s | step_;
126
94
}
127
95
128
- Scalar * val_ = {};
129
- size_t size_ = {};
96
+ const DataT * val_ = {};
97
+ size_t id_ = {};
130
98
int32_t step_ = {};
131
99
bool owning_ = false ;
132
100
};
@@ -158,103 +126,126 @@ struct RecursiveDoubling {
158
126
* \param num_nodes The number of nodes.
159
127
* \param args Additional arguments for data initialization.
160
128
*/
129
+ template <typename ... Args>
161
130
RecursiveDoubling (
162
131
vt::objgroup::proxy::Proxy<ObjT> parentProxy, NodeType num_nodes,
163
- const DataT& data);
132
+ Args&&... data);
164
133
165
134
/* *
166
135
* \brief Start the allreduce operation.
167
136
*/
168
- void allreduce ();
137
+ void allreduce (size_t id );
169
138
170
139
/* *
171
140
* \brief Initialize the RecursiveDoubling object.
172
141
*
173
142
* \param args Additional arguments for data initialization.
174
143
*/
175
- void initialize (const DataT& data);
144
+ template <typename ... Args>
145
+ void initialize (size_t id, Args&&... data);
146
+ void initializeState (size_t id);
147
+
148
+ size_t generateNewId () { return id_++; }
176
149
177
150
/* *
178
151
* \brief Adjust for power of two nodes.
179
152
*/
180
- void adjustForPowerOfTwo ();
153
+ void adjustForPowerOfTwo (size_t id );
181
154
182
155
/* *
183
156
* \brief Handler for adjusting for power of two nodes.
184
157
*
185
158
* \param msg Pointer to the message.
186
159
*/
187
- void adjustForPowerOfTwoHandler (AllreduceDblRawMsg<Scalar >* msg);
160
+ void adjustForPowerOfTwoHandler (AllreduceDblRawMsg<DataT >* msg);
188
161
189
162
/* *
190
163
* \brief Check if the allreduce operation is done.
191
164
*
192
165
* \return True if the operation is done, otherwise false.
193
166
*/
194
- bool done ( );
167
+ bool isDone ( size_t id );
195
168
196
169
/* *
197
170
* \brief Check if the current state is valid for allreduce.
198
171
*
199
172
* \return True if the state is valid, otherwise false.
200
173
*/
201
- bool isValid ();
174
+ bool isValid (size_t id );
202
175
203
176
/* *
204
177
* \brief Check if all messages are received for the current step.
205
178
*
206
179
* \return True if all messages are received, otherwise false.
207
180
*/
208
- bool allMessagesReceived ();
181
+ bool allMessagesReceived (size_t id );
209
182
210
183
/* *
211
184
* \brief Check if the object is ready for the next step of allreduce.
212
185
*
213
186
* \return True if ready, otherwise false.
214
187
*/
215
- bool isReady ();
188
+ bool isReady (size_t id );
216
189
217
190
/* *
218
191
* \brief Perform the next step of the allreduce operation.
219
192
*/
220
- void reduceIter ();
193
+ void reduceIter (size_t id );
221
194
222
195
/* *
223
196
* \brief Try to reduce the message at the specified step.
224
197
*
225
198
* \param step The step at which to try reduction.
226
199
*/
227
- void tryReduce (int32_t step);
200
+ void tryReduce (size_t id, int32_t step);
228
201
229
202
/* *
230
203
* \brief Handler for the reduce iteration.
231
204
*
232
205
* \param msg Pointer to the message.
233
206
*/
234
- void reduceIterHandler (AllreduceDblRawMsg<Scalar >* msg);
207
+ void reduceIterHandler (AllreduceDblRawMsg<DataT >* msg);
235
208
236
209
/* *
237
210
* \brief Send data to excluded nodes for finalization.
238
211
*/
239
- void sendToExcludedNodes ();
212
+ void sendToExcludedNodes (size_t id );
240
213
241
214
/* *
242
215
* \brief Handler for sending data to excluded nodes.
243
216
*
244
217
* \param msg Pointer to the message.
245
218
*/
246
- void sendToExcludedNodesHandler (AllreduceDblRawMsg<Scalar >* msg);
219
+ void sendToExcludedNodesHandler (AllreduceDblRawMsg<DataT >* msg);
247
220
248
221
/* *
249
222
* \brief Perform the final part of the allreduce operation.
250
223
*/
251
- void finalPart ();
224
+ void finalPart (size_t id );
252
225
253
226
vt::objgroup::proxy::Proxy<RecursiveDoubling> proxy_ = {};
254
227
vt::objgroup::proxy::Proxy<ObjT> parent_proxy_ = {};
255
228
256
229
// DataT val_ = {};
257
- std::vector<Scalar> val_;
230
+
231
+ struct State {
232
+ DataT val_ = {};
233
+ bool finished_adjustment_part_ = false ;
234
+ MsgSharedPtr<AllreduceDblRawMsg<DataT>> adjust_message_ = nullptr ;
235
+
236
+ int32_t mask_ = 1 ;
237
+ int32_t step_ = 0 ;
238
+ bool initialized_ = false ;
239
+ bool completed_ = false ;
240
+
241
+ std::vector<bool > steps_recv_ = {};
242
+ std::vector<bool > steps_reduced_ = {};
243
+ std::vector<MsgSharedPtr<AllreduceDblRawMsg<DataT>>> messages_ = {};
244
+ };
245
+
246
+ size_t id_ = 0 ;
247
+ std::unordered_map<size_t , State> states_ = {};
248
+
258
249
NodeType num_nodes_ = {};
259
250
NodeType this_node_ = {};
260
251
@@ -265,17 +256,6 @@ struct RecursiveDoubling {
265
256
266
257
NodeType vrt_node_ = {};
267
258
bool is_part_of_adjustment_group_ = false ;
268
- bool finished_adjustment_part_ = false ;
269
-
270
- int32_t mask_ = 1 ;
271
- int32_t step_ = 0 ;
272
-
273
- bool completed_ = false ;
274
-
275
- std::vector<bool > steps_recv_ = {};
276
- std::vector<bool > steps_reduced_ = {};
277
-
278
- std::vector<MsgSharedPtr<AllreduceDblRawMsg<Scalar>>> messages_ = {};
279
259
};
280
260
281
261
} // namespace vt::collective::reduce::allreduce
0 commit comments