@@ -42,6 +42,7 @@ c10::intrusive_ptr<Work> reduce_XPU(
     const c10::intrusive_ptr<ReduceOp>& reduce_op,
     int64_t root_rank,
     int64_t root_tensor,
+    bool asyncOp,
     int64_t timeout) {
   auto tensor_vec = tensors.vec();
   return process_group->getBackend(c10::DeviceType::XPU)
@@ -51,7 +52,8 @@ c10::intrusive_ptr<Work> reduce_XPU(
               *reduce_op.get(),
               root_rank,
               root_tensor,
-              std::chrono::milliseconds(timeout)});
+              std::chrono::milliseconds(timeout),
+              asyncOp});
 }
 
 std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> broadcast_XPU(
@@ -79,14 +81,16 @@ std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> allreduce_XPU(
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     const c10::intrusive_ptr<ReduceOp>& reduce_op,
     const std::optional<at::Tensor>& sparse_indices,
+    bool asyncOp,
     int64_t timeout) {
   auto tensor_vec = tensors.vec();
-  auto work =
-      process_group->getBackend(c10::DeviceType::XPU)
-          ->allreduce(
-              tensor_vec,
-              AllreduceOptions{
-                  *reduce_op.get(), std::chrono::milliseconds(timeout)});
+  auto work = process_group->getBackend(c10::DeviceType::XPU)
+                  ->allreduce(
+                      tensor_vec,
+                      AllreduceOptions{
+                          *reduce_op.get(),
+                          std::chrono::milliseconds(timeout),
+                          asyncOp});
   return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
       std::move(tensor_vec), work);
 }
@@ -95,11 +99,13 @@ c10::intrusive_ptr<Work> allreduce_coalesced_XPU(
     at::TensorList tensors,
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     const c10::intrusive_ptr<ReduceOp>& reduce_op,
+    bool asyncOp,
     int64_t timeout) {
   auto tensor_vec = tensors.vec();
   AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{};
   opts.reduceOp = *reduce_op.get();
   opts.timeout = std::chrono::milliseconds(timeout);
+  opts.asyncOp = asyncOp;
   return process_group->getBackend(c10::DeviceType::XPU)
       ->allreduce_coalesced(tensor_vec, opts);
 }
@@ -109,14 +115,15 @@ allgather_XPU(
     const std::vector<std::vector<at::Tensor>>& output_tensors,
     at::TensorList input_tensors,
     const c10::intrusive_ptr<ProcessGroup>& process_group,
+    bool asyncOp,
     int64_t timeout) {
   auto input_tensors_vec = input_tensors.vec();
   auto work =
       process_group->getBackend(c10::DeviceType::XPU)
           ->allgather(
               const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors),
               input_tensors_vec,
-              AllgatherOptions{std::chrono::milliseconds(timeout)});
+              AllgatherOptions{std::chrono::milliseconds(timeout), asyncOp});
   return std::
       tuple<std::vector<std::vector<at::Tensor>>, c10::intrusive_ptr<Work>>(
           output_tensors, work);
@@ -140,29 +147,37 @@ std::tuple<at::Tensor, c10::intrusive_ptr<Work>> _allgather_base_XPU(
 c10::intrusive_ptr<Work> allgather_coalesced_XPU(
     const std::vector<std::vector<at::Tensor>>& output_lists,
     const at::TensorList& input_list,
-    const c10::intrusive_ptr<ProcessGroup>& process_group) {
+    const c10::intrusive_ptr<ProcessGroup>& process_group,
+    bool asyncOp) {
   auto input_list_vec = input_list.vec();
+  auto opts = AllgatherOptions{};
+  opts.asyncOp = asyncOp;
   return process_group->getBackend(c10::DeviceType::XPU)
       ->allgather_coalesced(
           const_cast<std::vector<std::vector<at::Tensor>>&>(output_lists),
-          input_list_vec);
+          input_list_vec,
+          opts);
 }
 
 c10::intrusive_ptr<c10d::Work> allgather_into_tensor_coalesced_XPU(
     at::TensorList outputs,
     at::TensorList inputs,
-    const c10::intrusive_ptr<ProcessGroup>& process_group) {
+    const c10::intrusive_ptr<ProcessGroup>& process_group,
+    bool asyncOp) {
   auto output_vec = outputs.vec();
   auto input_vec = inputs.vec();
+  auto opts = AllgatherOptions{};
+  opts.asyncOp = asyncOp;
   return process_group->getBackend(c10::DeviceType::XPU)
-      ->allgather_into_tensor_coalesced(output_vec, input_vec);
+      ->allgather_into_tensor_coalesced(output_vec, input_vec, opts);
 }
 
 std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> reduce_scatter_XPU(
     const at::TensorList& output_tensors,
     const std::vector<std::vector<at::Tensor>>& input_tensors,
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     const c10::intrusive_ptr<ReduceOp>& reduce_op,
+    bool asyncOp,
     int64_t timeout) {
   auto output_tensors_vec = output_tensors.vec();
   auto work =
@@ -171,7 +186,9 @@ std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> reduce_scatter_XPU
               output_tensors_vec,
               const_cast<std::vector<std::vector<at::Tensor>>&>(input_tensors),
               ReduceScatterOptions{
-                  *reduce_op.get(), std::chrono::milliseconds(timeout)});
+                  *reduce_op.get(),
+                  std::chrono::milliseconds(timeout),
+                  asyncOp});
   return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
       output_tensors_vec, work);
 }
@@ -199,6 +216,7 @@ c10::intrusive_ptr<c10d::Work> reduce_scatter_tensor_coalesced_XPU(
     at::TensorList inputs,
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     const c10::intrusive_ptr<ReduceOp>& reduce_op,
+    bool asyncOp,
     int64_t timeout) {
   auto output_vec = outputs.vec();
   auto input_vec = inputs.vec();
@@ -207,21 +225,23 @@ c10::intrusive_ptr<c10d::Work> reduce_scatter_tensor_coalesced_XPU(
           output_vec,
           input_vec,
           ReduceScatterOptions{
-              *reduce_op.get(), std::chrono::milliseconds(timeout)});
+              *reduce_op.get(), std::chrono::milliseconds(timeout), asyncOp});
 }
 
 c10::intrusive_ptr<Work> gather_XPU(
     const std::vector<std::vector<at::Tensor>>& output_tensors,
     const at::TensorList& input_tensors,
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     int64_t root_rank,
+    bool asyncOp,
     int64_t timeout) {
   auto input_tensors_vec = input_tensors.vec();
   return process_group->getBackend(c10::DeviceType::XPU)
       ->gather(
           const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors),
           input_tensors_vec,
-          GatherOptions{root_rank, std::chrono::milliseconds(timeout)});
+          GatherOptions{
+              root_rank, std::chrono::milliseconds(timeout), asyncOp});
 }
 
 std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> scatter_XPU(
@@ -247,14 +267,16 @@ std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> alltoall_XPU(
     const at::TensorList& output_tensors,
     const at::TensorList& input_tensors,
     const c10::intrusive_ptr<ProcessGroup>& process_group,
+    bool asyncOp,
     int64_t timeout) {
   auto output_tensors_vec = output_tensors.vec();
   auto input_tensors_vec = input_tensors.vec();
-  auto work = process_group->getBackend(c10::DeviceType::XPU)
-                  ->alltoall(
-                      output_tensors_vec,
-                      input_tensors_vec,
-                      AllToAllOptions{std::chrono::milliseconds(timeout)});
+  auto work =
+      process_group->getBackend(c10::DeviceType::XPU)
+          ->alltoall(
+              output_tensors_vec,
+              input_tensors_vec,
+              AllToAllOptions{std::chrono::milliseconds(timeout), asyncOp});
   return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
       std::move(output_tensors_vec), work);
 }
@@ -265,23 +287,28 @@ c10::intrusive_ptr<Work> alltoall_base_XPU(
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     std::vector<int64_t> output_split_sizes,
     std::vector<int64_t> input_split_sizes,
+    bool asyncOp,
     int64_t timeout) {
   return process_group->getBackend(c10::DeviceType::XPU)
       ->alltoall_base(
           output,
           input,
           output_split_sizes,
           input_split_sizes,
-          AllToAllOptions{std::chrono::milliseconds(timeout)});
+          AllToAllOptions{std::chrono::milliseconds(timeout), asyncOp});
 }
 
 c10::intrusive_ptr<Work> barrier_XPU(
     at::Tensor /* unused */,
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     const std::vector<int64_t>& device_ids,
+    bool asyncOp,
     int64_t timeout) {
-  return process_group->getBackend(c10::DeviceType::XPU)
-      ->barrier(BarrierOptions{device_ids, std::chrono::milliseconds(timeout)});
+  auto opts = BarrierOptions{};
+  opts.device_ids = device_ids;
+  opts.timeout = std::chrono::milliseconds(timeout);
+  opts.asyncOp = asyncOp;
+  return process_group->getBackend(c10::DeviceType::XPU)->barrier(opts);
 }
 
 TORCH_LIBRARY_IMPL(c10d, XPU, m) {
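Taken together, these hunks thread a new `asyncOp` flag from the dispatcher-registered XPU wrappers into the corresponding c10d `*Options` structs, either as the last aggregate initializer (reduce, allreduce, allgather, reduce_scatter, gather, alltoall) or by setting the field on a default-constructed options object (allreduce_coalesced, allgather_coalesced, barrier). The sketch below is not part of the diff; it illustrates the same pattern from the caller side. The helper name `example_allreduce`, the process group `pg`, and the 60-second timeout are illustrative assumptions, while `AllreduceOptions` and its `reduceOp`, `timeout`, and `asyncOp` fields are the c10d types the diff relies on.

```cpp
// Minimal caller-side sketch (assumed setup): populate AllreduceOptions the
// same way the XPU wrappers now do, then hand it to the process group.
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>

c10::intrusive_ptr<c10d::Work> example_allreduce(
    const c10::intrusive_ptr<c10d::ProcessGroup>& pg,
    std::vector<at::Tensor>& tensors,
    bool asyncOp) {
  c10d::AllreduceOptions opts;
  opts.reduceOp = c10d::ReduceOp::SUM;
  opts.timeout = std::chrono::milliseconds(60000);  // illustrative timeout
  opts.asyncOp = asyncOp;  // same field the XPU wrappers now forward
  return pg->allreduce(tensors, opts);
}
```

Setting the field on a named options object, as the coalesced and barrier wrappers do, avoids depending on the member order of the aggregate, which is why this sketch uses assignments rather than brace initialization.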