Skip to content

Commit 07c5eb7

Browse files
committed
merge allreduce
1 parent ba8b85f commit 07c5eb7

File tree

1 file changed

+4
-32
lines changed

1 file changed

+4
-32
lines changed

src/xccl/ProcessGroupXCCL.cpp

+4-32
Original file line numberDiff line numberDiff line change
@@ -897,6 +897,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::scatter(
897897

898898
c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_impl(
899899
at::Tensor& tensor,
900+
const char* profilingTitle,
900901
const AllreduceOptions& opts) {
901902
return collective(
902903
tensor,
@@ -928,7 +929,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_impl(
928929
return;
929930
},
930931
OpType::ALLREDUCE,
931-
"xccl:all_reduce");
932+
profilingTitle);
932933
}
933934

934935
c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
@@ -956,36 +957,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
956957
-1, // globalRankStride
957958
size_); // worldSize
958959

959-
return collective(
960-
tensor,
961-
tensor,
962-
[&](at::Tensor& input,
963-
at::Tensor& output,
964-
xcclComm_t& comm,
965-
at::xpu::XPUStream& stream) {
966-
auto xcclDataType = getXcclDataType(input.scalar_type(), true);
967-
auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
968-
ccl::allreduce(
969-
input.data_ptr(),
970-
output.data_ptr(),
971-
(size_t)input.numel(),
972-
xcclDataType,
973-
xcclReduceOp,
974-
comm,
975-
ccl::create_stream(stream.queue()));
976-
#if !defined(XCCL_HAS_AVG)
977-
if (opts.reduceOp == ReduceOp::AVG) {
978-
auto divisor = getSize();
979-
c10::StreamGuard guard(stream);
980-
c10::xpu::XPUCachingAllocator::recordStream(
981-
output.storage().data_ptr(), stream);
982-
output.div_(divisor);
983-
}
984-
#endif
985-
return;
986-
},
987-
OpType::ALLREDUCE,
988-
"xccl:all_reduce");
960+
return allreduce_impl(tensor, "xccl:all_reduce", opts);
989961
}
990962

991963
c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
@@ -1621,7 +1593,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::barrier(const BarrierOptions& opts) {
16211593
at::Tensor barrierTensor =
16221594
at::zeros({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat));
16231595

1624-
auto work = allreduce_impl(barrierTensor);
1596+
auto work = allreduce_impl(barrierTensor, "xccl:all_reduce_barrier");
16251597

16261598
auto xcclWork = dynamic_cast<ProcessGroupXCCL::WorkXCCL*>(work.get());
16271599
TORCH_CHECK(xcclWork);

0 commit comments

Comments
 (0)