
Commit 0fdb0ab

Update xccl backend registration to align with the latest torch
1 parent 56e7dda commit 0fdb0ab
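
This commit aligns the XPU wrappers registered for the XCCL backend with the newer torch c10d op schemas, which pass an explicit asyncOp flag to each collective. Each affected *_XPU wrapper below gains a bool asyncOp parameter and forwards it to the backend through the corresponding options struct, either as the last aggregate initializer or by setting opts.asyncOp.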

1 file changed (+50, -23 lines)

src/xccl/Register.cpp

@@ -42,6 +42,7 @@ c10::intrusive_ptr<Work> reduce_XPU(
     const c10::intrusive_ptr<ReduceOp>& reduce_op,
     int64_t root_rank,
     int64_t root_tensor,
+    bool asyncOp,
     int64_t timeout) {
   auto tensor_vec = tensors.vec();
   return process_group->getBackend(c10::DeviceType::XPU)
@@ -51,7 +52,8 @@ c10::intrusive_ptr<Work> reduce_XPU(
              *reduce_op.get(),
              root_rank,
              root_tensor,
-             std::chrono::milliseconds(timeout)});
+             std::chrono::milliseconds(timeout),
+             asyncOp});
 }
 
 std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> broadcast_XPU(
@@ -79,14 +81,16 @@ std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> allreduce_XPU(
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     const c10::intrusive_ptr<ReduceOp>& reduce_op,
     const std::optional<at::Tensor>& sparse_indices,
+    bool asyncOp,
     int64_t timeout) {
   auto tensor_vec = tensors.vec();
-  auto work =
-      process_group->getBackend(c10::DeviceType::XPU)
-          ->allreduce(
-              tensor_vec,
-              AllreduceOptions{
-                  *reduce_op.get(), std::chrono::milliseconds(timeout)});
+  auto work = process_group->getBackend(c10::DeviceType::XPU)
+                  ->allreduce(
+                      tensor_vec,
+                      AllreduceOptions{
+                          *reduce_op.get(),
+                          std::chrono::milliseconds(timeout),
+                          asyncOp});
   return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
       std::move(tensor_vec), work);
 }
@@ -95,11 +99,13 @@ c10::intrusive_ptr<Work> allreduce_coalesced_XPU(
     at::TensorList tensors,
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     const c10::intrusive_ptr<ReduceOp>& reduce_op,
+    bool asyncOp,
     int64_t timeout) {
   auto tensor_vec = tensors.vec();
   AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{};
   opts.reduceOp = *reduce_op.get();
   opts.timeout = std::chrono::milliseconds(timeout);
+  opts.asyncOp = asyncOp;
   return process_group->getBackend(c10::DeviceType::XPU)
       ->allreduce_coalesced(tensor_vec, opts);
 }
@@ -109,14 +115,15 @@ allgather_XPU(
     const std::vector<std::vector<at::Tensor>>& output_tensors,
     at::TensorList input_tensors,
     const c10::intrusive_ptr<ProcessGroup>& process_group,
+    bool asyncOp,
     int64_t timeout) {
   auto input_tensors_vec = input_tensors.vec();
   auto work =
       process_group->getBackend(c10::DeviceType::XPU)
           ->allgather(
               const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors),
               input_tensors_vec,
-              AllgatherOptions{std::chrono::milliseconds(timeout)});
+              AllgatherOptions{std::chrono::milliseconds(timeout), asyncOp});
   return std::
       tuple<std::vector<std::vector<at::Tensor>>, c10::intrusive_ptr<Work>>(
           output_tensors, work);
@@ -140,29 +147,37 @@ std::tuple<at::Tensor, c10::intrusive_ptr<Work>> _allgather_base_XPU(
 c10::intrusive_ptr<Work> allgather_coalesced_XPU(
     const std::vector<std::vector<at::Tensor>>& output_lists,
     const at::TensorList& input_list,
-    const c10::intrusive_ptr<ProcessGroup>& process_group) {
+    const c10::intrusive_ptr<ProcessGroup>& process_group,
+    bool asyncOp) {
   auto input_list_vec = input_list.vec();
+  auto opts = AllgatherOptions{};
+  opts.asyncOp = asyncOp;
   return process_group->getBackend(c10::DeviceType::XPU)
       ->allgather_coalesced(
           const_cast<std::vector<std::vector<at::Tensor>>&>(output_lists),
-          input_list_vec);
+          input_list_vec,
+          opts);
 }
 
 c10::intrusive_ptr<c10d::Work> allgather_into_tensor_coalesced_XPU(
     at::TensorList outputs,
     at::TensorList inputs,
-    const c10::intrusive_ptr<ProcessGroup>& process_group) {
+    const c10::intrusive_ptr<ProcessGroup>& process_group,
+    bool asyncOp) {
   auto output_vec = outputs.vec();
   auto input_vec = inputs.vec();
+  auto opts = AllgatherOptions{};
+  opts.asyncOp = asyncOp;
   return process_group->getBackend(c10::DeviceType::XPU)
-      ->allgather_into_tensor_coalesced(output_vec, input_vec);
+      ->allgather_into_tensor_coalesced(output_vec, input_vec, opts);
 }
 
 std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> reduce_scatter_XPU(
     const at::TensorList& output_tensors,
     const std::vector<std::vector<at::Tensor>>& input_tensors,
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     const c10::intrusive_ptr<ReduceOp>& reduce_op,
+    bool asyncOp,
     int64_t timeout) {
   auto output_tensors_vec = output_tensors.vec();
   auto work =
@@ -171,7 +186,9 @@ std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> reduce_scatter_XPU
             output_tensors_vec,
             const_cast<std::vector<std::vector<at::Tensor>>&>(input_tensors),
             ReduceScatterOptions{
-                *reduce_op.get(), std::chrono::milliseconds(timeout)});
+                *reduce_op.get(),
+                std::chrono::milliseconds(timeout),
+                asyncOp});
   return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
       output_tensors_vec, work);
 }
@@ -199,6 +216,7 @@ c10::intrusive_ptr<c10d::Work> reduce_scatter_tensor_coalesced_XPU(
     at::TensorList inputs,
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     const c10::intrusive_ptr<ReduceOp>& reduce_op,
+    bool asyncOp,
     int64_t timeout) {
   auto output_vec = outputs.vec();
   auto input_vec = inputs.vec();
@@ -207,21 +225,23 @@ c10::intrusive_ptr<c10d::Work> reduce_scatter_tensor_coalesced_XPU(
           output_vec,
           input_vec,
           ReduceScatterOptions{
-              *reduce_op.get(), std::chrono::milliseconds(timeout)});
+              *reduce_op.get(), std::chrono::milliseconds(timeout), asyncOp});
 }
 
 c10::intrusive_ptr<Work> gather_XPU(
     const std::vector<std::vector<at::Tensor>>& output_tensors,
     const at::TensorList& input_tensors,
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     int64_t root_rank,
+    bool asyncOp,
     int64_t timeout) {
   auto input_tensors_vec = input_tensors.vec();
   return process_group->getBackend(c10::DeviceType::XPU)
       ->gather(
           const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors),
           input_tensors_vec,
-          GatherOptions{root_rank, std::chrono::milliseconds(timeout)});
+          GatherOptions{
+              root_rank, std::chrono::milliseconds(timeout), asyncOp});
 }
 
 std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> scatter_XPU(
@@ -247,14 +267,16 @@ std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> alltoall_XPU(
     const at::TensorList& output_tensors,
     const at::TensorList& input_tensors,
     const c10::intrusive_ptr<ProcessGroup>& process_group,
+    bool asyncOp,
     int64_t timeout) {
   auto output_tensors_vec = output_tensors.vec();
   auto input_tensors_vec = input_tensors.vec();
-  auto work = process_group->getBackend(c10::DeviceType::XPU)
-                  ->alltoall(
-                      output_tensors_vec,
-                      input_tensors_vec,
-                      AllToAllOptions{std::chrono::milliseconds(timeout)});
+  auto work =
+      process_group->getBackend(c10::DeviceType::XPU)
+          ->alltoall(
+              output_tensors_vec,
+              input_tensors_vec,
+              AllToAllOptions{std::chrono::milliseconds(timeout), asyncOp});
   return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
       std::move(output_tensors_vec), work);
 }
@@ -265,23 +287,28 @@ c10::intrusive_ptr<Work> alltoall_base_XPU(
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     std::vector<int64_t> output_split_sizes,
     std::vector<int64_t> input_split_sizes,
+    bool asyncOp,
     int64_t timeout) {
   return process_group->getBackend(c10::DeviceType::XPU)
       ->alltoall_base(
          output,
          input,
          output_split_sizes,
          input_split_sizes,
-          AllToAllOptions{std::chrono::milliseconds(timeout)});
+          AllToAllOptions{std::chrono::milliseconds(timeout), asyncOp});
 }
 
 c10::intrusive_ptr<Work> barrier_XPU(
     at::Tensor /* unused */,
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     const std::vector<int64_t>& device_ids,
+    bool asyncOp,
     int64_t timeout) {
-  return process_group->getBackend(c10::DeviceType::XPU)
-      ->barrier(BarrierOptions{device_ids, std::chrono::milliseconds(timeout)});
+  auto opts = BarrierOptions{};
+  opts.device_ids = device_ids;
+  opts.timeout = std::chrono::milliseconds(timeout);
+  opts.asyncOp = asyncOp;
+  return process_group->getBackend(c10::DeviceType::XPU)->barrier(opts);
 }
 
 TORCH_LIBRARY_IMPL(c10d, XPU, m) {
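
The last hunk stops at the opening of the registration block. For context, the wrappers above are bound to the c10d ops for the XPU dispatch key via m.impl calls; the following is only a rough sketch of that block, not part of this commit, and the exact op list in Register.cpp may differ.

// Rough sketch for context only; not part of this commit's diff.
// Each c10d op name is mapped to the XPU wrapper defined above.
TORCH_LIBRARY_IMPL(c10d, XPU, m) {
  m.impl("reduce_", reduce_XPU);
  m.impl("allreduce_", allreduce_XPU);
  m.impl("allreduce_coalesced_", allreduce_coalesced_XPU);
  m.impl("allgather_", allgather_XPU);
  m.impl("alltoall_", alltoall_XPU);
  m.impl("barrier", barrier_XPU);
  // ...and likewise for the remaining collectives implemented in this file.
}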
