From dec4b5aa637b8c0a22ac7c5c77c4710fa68de7c1 Mon Sep 17 00:00:00 2001
From: "Han, Chao1"
Date: Mon, 10 Mar 2025 21:25:58 +0800
Subject: [PATCH 1/4] fix singlep2p cclstream mapping

---
 src/xccl/ProcessGroupXCCL.cpp | 66 +++++++++++++++++++----------------
 src/xccl/ProcessGroupXCCL.hpp |  4 ++-
 2 files changed, 39 insertions(+), 31 deletions(-)

diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp
index 5d16cab07..74d0ff4a3 100644
--- a/src/xccl/ProcessGroupXCCL.cpp
+++ b/src/xccl/ProcessGroupXCCL.cpp
@@ -352,12 +352,11 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
     ccl::group_start();
   }

-  ccl::stream xccl_stream = ccl::create_stream(q);
+  auto xccl_stream = std::make_shared<ccl::stream>(ccl::create_stream(q));
   std::lock_guard<std::mutex> lock(mutex_);
   devXCCLCommMap_.emplace(deviceKey, XCCLComm);
   xcclStreamsMap_.emplace(
-      deviceKey,
-      std::make_pair(at::xpu::XPUStream(stream), std::move(xccl_stream)));
+      deviceKey, std::make_pair(at::xpu::XPUStream(stream), xccl_stream));
   xcclEventsMap_.emplace(deviceKey, at::xpu::XPUEvent());

   return XCCLComm;
@@ -600,15 +599,19 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::send(
         at::xpu::XPUStream& stream,
         int dst) {
       auto xcclDataType = getXcclDataType(input.scalar_type());
-      auto xcclStream =
-          xcclStreamsMap_.at(std::to_string(tensor.device().index())).second;
+      std::shared_ptr<ccl::stream> xcclStream = nullptr;
+      for (const auto& pair : xcclStreamsMap_) {
+        if (pair.second.first == stream) {
+          xcclStream = pair.second.second;
+        }
+      }
       ccl::send(
           input.data_ptr(),
           (size_t)input.numel(),
           xcclDataType,
           dst,
           comm,
-          xcclStream);
+          *xcclStream);
       return;
     },
     dstRank,
@@ -650,15 +653,19 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::recv(
         at::xpu::XPUStream& stream,
         int src) {
       auto xcclDataType = getXcclDataType(output.scalar_type());
-      auto xcclStream =
-          xcclStreamsMap_.at(std::to_string(tensor.device().index())).second;
+      std::shared_ptr<ccl::stream> xcclStream = nullptr;
+      for (const auto& pair : xcclStreamsMap_) {
+        if (pair.second.first == stream) {
+          xcclStream = pair.second.second;
+        }
+      }
       ccl::recv(
           output.data_ptr(),
           (size_t)output.numel(),
           xcclDataType,
           src,
           comm,
-          xcclStream);
+          *xcclStream);
       return;
     },
     srcRank,
@@ -759,7 +766,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::gather(
                   xcclDataType,
                   r,
                   comm,
-                  xcclStream);
+                  *xcclStream);
             } else {
               // on its own rank, simply copy from the input
               outputs[r].copy_(inputTensor);
@@ -773,7 +780,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::gather(
               xcclDataType,
               root,
               comm,
-              xcclStream);
+              *xcclStream);
         }
         return;
       }
@@ -875,7 +882,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::scatter(
                   send_type,
                   r,
                   comm,
-                  xcclStream);
+                  *xcclStream);
             } else {
               // on its own rank, simply copy from the input
               outputTensor.copy_(inputs[r]);
@@ -891,7 +898,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::scatter(
               recv_type,
               root,
               comm,
-              xcclStream);
+              *xcclStream);
         }

         return;
@@ -921,7 +928,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_impl(
           xcclDataType,
           xcclReduceOp,
           comm,
-          xcclStream);
+          *xcclStream);
       // Use SUM emu AVG due to oneCCL not support AVG
       // oneCCL is expected to support avg in basekit 2025.2 release.
       if (opts.reduceOp == ReduceOp::AVG) {
@@ -977,7 +984,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
           xcclDataType,
           xcclReduceOp,
           comm,
-          xcclStream);
+          *xcclStream);
       // Use SUM emu AVG due to oneCCL not support AVG
       // oneCCL is expected to support avg in basekit 2025.2 release.
       if (opts.reduceOp == ReduceOp::AVG) {
@@ -1035,7 +1042,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
           xcclDataType,
           xcclReduceOp,
           comm,
-          xcclStream);
+          *xcclStream);
       // Use SUM emu AVG due to oneCCL not support AVG
       // oneCCL is expected to support avg in basekit 2025.2 release.
       if (opts.reduceOp == ReduceOp::AVG) {
@@ -1094,7 +1101,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::broadcast(
           xcclDataType,
           root,
           comm,
-          xcclStream);
+          *xcclStream);
       return;
     },
     OpType::BROADCAST,
@@ -1129,7 +1136,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_broadcast_oop(
           xcclDataType,
           root,
           comm,
-          xcclStream);
+          *xcclStream);
       return;
     },
     OpType::BROADCAST,
@@ -1180,7 +1187,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce(
           xcclReduceOp,
           root,
           comm,
-          xcclStream);
+          *xcclStream);
       // WA due to oneCCL not support AVG
       if (opts.reduceOp == ReduceOp::AVG && getRank() == root) {
         auto divisor = getSize();
@@ -1224,7 +1231,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_oop(
           xcclReduceOp,
           root,
           comm,
-          xcclStream);
+          *xcclStream);
       // Use SUM emu AVG due to oneCCL not support AVG
       // oneCCL is expected to support avg in basekit 2025.2 release.
       if (opts.reduceOp == ReduceOp::AVG && getRank() == root) {
@@ -1293,7 +1300,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allgather(
           (size_t)input.numel(),
           xcclDataType,
           comm,
-          xcclStream);
+          *xcclStream);
       return;
     },
     [](at::xpu::XPUStream&,
@@ -1377,7 +1384,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_allgather_base(
           (size_t)input.numel(),
           xcclDataType,
           comm,
-          xcclStream);
+          *xcclStream);
       return;
     },
     OpType::_ALLGATHER_BASE,
@@ -1405,7 +1412,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allgather_into_tensor_coalesced(
           (size_t)input.numel(),
           xcclDataType,
           comm,
-          xcclStream);
+          *xcclStream);
       return;
     },
     OpType::COALESCED,
@@ -1466,7 +1473,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter(
           xcclDataType,
           xcclReduceOp,
           comm,
-          xcclStream);
+          *xcclStream);
       // Use SUM emu AVG due to oneCCL not support AVG
       // oneCCL is expected to support avg in basekit 2025.2 release.
       if (opts.reduceOp == ReduceOp::AVG) {
@@ -1561,7 +1568,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_scatter_base(
           xcclDataType,
           xcclReduceOp,
           comm,
-          xcclStream);
+          *xcclStream);
       // Use SUM emu AVG due to oneCCL not support AVG
       // oneCCL is expected to support avg in basekit 2025.2 release.
       if (opts.reduceOp == ReduceOp::AVG) {
@@ -1602,7 +1609,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter_tensor_coalesced(
           xcclDataType,
           xcclReduceOp,
           comm,
-          xcclStream);
+          *xcclStream);
       // Use SUM emu AVG due to oneCCL not support AVG
       // oneCCL is expected to support avg in basekit 2025.2 release.
       if (opts.reduceOp == ReduceOp::AVG) {
@@ -1647,7 +1654,6 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::barrier(const BarrierOptions& opts) {
     barDevIdx =
         static_cast<int16_t>(rank_ % at::detail::getXPUHooks().getNumGPUs());
   }
-
   TORCH_CHECK_WITH(
       ValueError,
       barDevIdx >= 0,
@@ -1717,7 +1723,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall_base(
           (size_t)output.numel() / comm.size(),
           xcclDataType,
           comm,
-          xcclStream);
+          *xcclStream);
       return;
     },
     OpType::ALLTOALL_BASE,
@@ -1780,7 +1786,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall_base(
           recvCounts,
           xcclDataType,
           comm,
-          xcclStream);
+          *xcclStream);
       return;
     },
     OpType::ALLTOALL_BASE,
@@ -1864,7 +1870,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall(
           recvCounts,
           xcclDataType,
           comm,
-          xcclStream);
+          *xcclStream);

       if (!isOutputFlat) {
         ret_evt.wait();

diff --git a/src/xccl/ProcessGroupXCCL.hpp b/src/xccl/ProcessGroupXCCL.hpp
index 59a6cd655..ecd9e422d 100644
--- a/src/xccl/ProcessGroupXCCL.hpp
+++ b/src/xccl/ProcessGroupXCCL.hpp
@@ -323,7 +323,9 @@ class TORCH_API ProcessGroupXCCL : public Backend {
   uint64_t getSequenceNumberForGroup() override;

 protected:
-  std::unordered_map<std::string, std::pair<at::xpu::XPUStream, ccl::stream>>
+  std::unordered_map<
+      std::string,
+      std::pair<at::xpu::XPUStream, std::shared_ptr<ccl::stream>>>
       xcclStreamsMap_;
   std::unordered_map<std::string, at::xpu::XPUEvent> xcclEventsMap_;
   std::unordered_map<std::string, std::shared_ptr<xcclComm_t>> devXCCLCommMap_;
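A note on the bug patch 1 works around: the removed lines fetched the cached
ccl stream with `xcclStreamsMap_.at(std::to_string(tensor.device().index()))`,
but `getXCCLComm` inserts entries under `deviceKey`, which for a communicator
created by a lone send/recv is a rank-pair style key rather than a bare device
index, so the `.at()` lookup throws on a freshly created process group. Below
is a minimal standalone sketch of that failure mode; the "0:1" key format and
the int stand-in for the cached `ccl::stream` are assumptions for illustration,
not the real API:

    #include <iostream>
    #include <map>
    #include <stdexcept>
    #include <string>

    int main() {
      // A single-P2P communicator is cached under a rank-pair style key
      // ("0:1" here, assumed format), not under the bare device index.
      std::map<std::string, int> cache; // int stands in for ccl::stream
      cache.emplace("0:1", 42);
      try {
        // The pre-patch send/recv path looked the entry up by device index.
        cache.at(std::to_string(0));
      } catch (const std::out_of_range&) {
        std::cout << "lookup by \"0\" fails: entry lives under \"0:1\"\n";
      }
      return 0;
    }

Patch 1 sidesteps the key mismatch by storing a `std::shared_ptr<ccl::stream>`
and linearly scanning the map for the entry whose XPU stream matches; patch 2
below drops that scan in favor of threading the stream through explicitly.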
From 963cf3630c122e8d00a556f72896972a35b29fa8 Mon Sep 17 00:00:00 2001
From: "Han, Chao1"
Date: Mon, 10 Mar 2025 22:33:17 +0800
Subject: [PATCH 2/4] update

---
 src/xccl/ProcessGroupXCCL.cpp | 177 +++++++++++++---------------------
 src/xccl/ProcessGroupXCCL.hpp |   4 +-
 2 files changed, 70 insertions(+), 111 deletions(-)

diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp
index 74d0ff4a3..647f1a617 100644
--- a/src/xccl/ProcessGroupXCCL.cpp
+++ b/src/xccl/ProcessGroupXCCL.cpp
@@ -352,11 +352,12 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
     ccl::group_start();
   }

-  auto xccl_stream = std::make_shared<ccl::stream>(ccl::create_stream(q));
+  ccl::stream xccl_stream = ccl::create_stream(q);
   std::lock_guard<std::mutex> lock(mutex_);
   devXCCLCommMap_.emplace(deviceKey, XCCLComm);
   xcclStreamsMap_.emplace(
-      deviceKey, std::make_pair(at::xpu::XPUStream(stream), xccl_stream));
+      deviceKey,
+      std::make_pair(at::xpu::XPUStream(stream), std::move(xccl_stream)));
   xcclEventsMap_.emplace(deviceKey, at::xpu::XPUEvent());

   return XCCLComm;
@@ -449,6 +450,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
   }

   auto stream = xcclStreamsMap_.at(key).first;
+  auto cclstream = xcclStreamsMap_.at(key).second;
   syncStream(device, xcclEventsMap_[key], stream);

   c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
@@ -463,7 +465,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
     for (const auto i : c10::irange(inputs.size())) {
       c10::xpu::XPUCachingAllocator::recordStream(
           inputs[i].storage().data_ptr(), stream);
-      fn(inputs[i], outputs[i], *comm, stream);
+      fn(inputs[i], outputs[i], *comm, stream, cclstream);
     }

     post(stream, work);
@@ -530,6 +532,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
   }

   auto stream = xcclStreamsMap_.at(key).first;
+  auto cclstream = xcclStreamsMap_.at(key).second;
   syncStream(device, xcclEventsMap_[key], stream);

   if (!coalescing_state_) {
@@ -543,7 +546,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
     c10::xpu::XPUCachingAllocator::recordStream(
         tensor.storage().data_ptr(), stream);

-    fn(tensor, *comm, stream, p2pTargetRank);
+    fn(tensor, *comm, stream, cclstream, p2pTargetRank);

     work->xcclEndEvent_->record(stream);
     work->blockingWait_ = blockingWait_;
@@ -560,7 +563,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
     c10::xpu::XPUCachingAllocator::recordStream(
         tensor.storage().data_ptr(), stream);

-    fn(tensor, *comm, stream, p2pTargetRank);
+    fn(tensor, *comm, stream, cclstream, p2pTargetRank);

     return nullptr;
   }
@@ -597,21 +600,16 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::send(
     [&](at::Tensor& input,
         xcclComm_t& comm,
         at::xpu::XPUStream& stream,
+        ccl::stream& xcclStream,
         int dst) {
       auto xcclDataType = getXcclDataType(input.scalar_type());
-      std::shared_ptr<ccl::stream> xcclStream = nullptr;
-      for (const auto& pair : xcclStreamsMap_) {
-        if (pair.second.first == stream) {
-          xcclStream = pair.second.second;
-        }
-      }
       ccl::send(
           input.data_ptr(),
           (size_t)input.numel(),
           xcclDataType,
           dst,
           comm,
-          *xcclStream);
+          xcclStream);
       return;
     },
     dstRank,
@@ -651,21 +649,16 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::recv(
     [&](at::Tensor& output,
         xcclComm_t& comm,
         at::xpu::XPUStream& stream,
+        ccl::stream& xcclStream,
         int src) {
       auto xcclDataType = getXcclDataType(output.scalar_type());
-      std::shared_ptr<ccl::stream> xcclStream = nullptr;
-      for (const auto& pair : xcclStreamsMap_) {
-        if (pair.second.first == stream) {
-          xcclStream = pair.second.second;
-        }
-      }
       ccl::recv(
           output.data_ptr(),
           (size_t)output.numel(),
           xcclDataType,
           src,
           comm,
-          *xcclStream);
+          xcclStream);
       return;
     },
     srcRank,
@@ -743,7 +736,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::gather(
     [&](at::Tensor& /* unused */,
         at::Tensor& /* unused */,
         xcclComm_t& comm,
-        at::xpu::XPUStream& stream) {
+        at::xpu::XPUStream& stream,
+        ccl::stream& xcclStream) {
       const auto root = opts.rootRank;
       if (getRank() == root) {
         for (auto output : outputs) {
@@ -753,9 +747,6 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::gather(
       }
       {
         auto xcclDataType = getXcclDataType(inputTensor.scalar_type());
-        auto xcclStream =
-            xcclStreamsMap_.at(std::to_string(inputs[0].device().index()))
-                .second;
         if (rank_ == root) {
           for (const auto r : c10::irange(size_)) {
             if (r != root) {
@@ -766,7 +757,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::gather(
                   xcclDataType,
                   r,
                   comm,
-                  *xcclStream);
+                  xcclStream);
             } else {
               // on its own rank, simply copy from the input
              outputs[r].copy_(inputTensor);
@@ -780,7 +771,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::gather(
               xcclDataType,
               root,
               comm,
-              *xcclStream);
+              xcclStream);
         }
         return;
       }
@@ -859,7 +850,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::scatter(
     [&](at::Tensor& /* unused */,
         at::Tensor& /* unused */,
         xcclComm_t& comm,
-        at::xpu::XPUStream& stream) {
+        at::xpu::XPUStream& stream,
+        ccl::stream& xcclStream) {
       if (getRank() == root) {
         for (auto input : inputs) {
           c10::xpu::XPUCachingAllocator::recordStream(
@@ -867,9 +859,6 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::scatter(
         }
       }
       {
-        auto xcclStream =
-            xcclStreamsMap_.at(std::to_string(outputs[0].device().index()))
-                .second;
         if (rank_ == root) {
           for (const auto r : c10::irange(size_)) {
             if (r != root) {
@@ -882,7 +871,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::scatter(
                   send_type,
                   r,
                   comm,
-                  *xcclStream);
+                  xcclStream);
             } else {
               // on its own rank, simply copy from the input
               outputTensor.copy_(inputs[r]);
@@ -898,7 +887,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::scatter(
               recv_type,
               root,
               comm,
-              *xcclStream);
+              xcclStream);
         }

         return;
@@ -916,11 +905,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_impl(
     [&](at::Tensor& input,
         at::Tensor& output,
         xcclComm_t& comm,
-        at::xpu::XPUStream& stream) {
+        at::xpu::XPUStream& stream,
+        ccl::stream& xcclStream) {
       auto xcclDataType = getXcclDataType(input.scalar_type(), true);
       auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-      auto xcclStream =
-          xcclStreamsMap_.at(std::to_string(tensor.device().index())).second;
       ccl::allreduce(
           input.data_ptr(),
           output.data_ptr(),
@@ -928,7 +916,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_impl(
           xcclDataType,
           xcclReduceOp,
           comm,
-          *xcclStream);
+          xcclStream);
       // Use SUM emu AVG due to oneCCL not support AVG
       // oneCCL is expected to support avg in basekit 2025.2 release.
       if (opts.reduceOp == ReduceOp::AVG) {
@@ -972,11 +960,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
     [&](at::Tensor& input,
         at::Tensor& output,
         xcclComm_t& comm,
-        at::xpu::XPUStream& stream) {
+        at::xpu::XPUStream& stream,
+        ccl::stream& xcclStream) {
       auto xcclDataType = getXcclDataType(input.scalar_type(), true);
       auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-      auto xcclStream =
-          xcclStreamsMap_.at(std::to_string(tensor.device().index())).second;
       ccl::allreduce(
           input.data_ptr(),
           output.data_ptr(),
@@ -984,7 +971,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
           xcclDataType,
           xcclReduceOp,
           comm,
-          *xcclStream);
+          xcclStream);
       // Use SUM emu AVG due to oneCCL not support AVG
       // oneCCL is expected to support avg in basekit 2025.2 release.
       if (opts.reduceOp == ReduceOp::AVG) {
@@ -1029,12 +1016,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
     [&](at::Tensor& input,
         at::Tensor& output,
         xcclComm_t& comm,
-        at::xpu::XPUStream& stream) {
+        at::xpu::XPUStream& stream,
+        ccl::stream& xcclStream) {
       auto xcclDataType = getXcclDataType(input.scalar_type(), true);
       auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-      auto xcclStream =
-          xcclStreamsMap_.at(std::to_string(tensors[0].device().index()))
-              .second;
       ccl::allreduce(
           input.data_ptr(),
           output.data_ptr(),
@@ -1042,7 +1027,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
           xcclDataType,
           xcclReduceOp,
           comm,
-          *xcclStream);
+          xcclStream);
       // Use SUM emu AVG due to oneCCL not support AVG
       // oneCCL is expected to support avg in basekit 2025.2 release.
       if (opts.reduceOp == ReduceOp::AVG) {
@@ -1091,17 +1076,16 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::broadcast(
     [&](at::Tensor& input,
         at::Tensor& output,
         xcclComm_t& comm,
-        at::xpu::XPUStream& stream) {
+        at::xpu::XPUStream& stream,
+        ccl::stream& xcclStream) {
       auto xcclDataType = getXcclDataType(input.scalar_type());
-      auto xcclStream =
-          xcclStreamsMap_.at(std::to_string(tensor.device().index())).second;
       ccl::broadcast(
           input.data_ptr(),
           (size_t)input.numel(),
           xcclDataType,
           root,
           comm,
-          *xcclStream);
+          xcclStream);
       return;
     },
     OpType::BROADCAST,
@@ -1124,11 +1108,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_broadcast_oop(
     [&](at::Tensor& input,
         at::Tensor& output,
         xcclComm_t& comm,
-        at::xpu::XPUStream& stream) {
+        at::xpu::XPUStream& stream,
+        ccl::stream& xcclStream) {
       auto xcclDataType = getXcclDataType(input.scalar_type());
-      auto xcclStream =
-          xcclStreamsMap_.at(std::to_string(inputTensor[0].device().index()))
-              .second;
       ccl::broadcast(
           input.data_ptr(),
           output.data_ptr(),
@@ -1136,7 +1118,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_broadcast_oop(
           xcclDataType,
           root,
           comm,
-          *xcclStream);
+          xcclStream);
       return;
     },
     OpType::BROADCAST,
@@ -1173,12 +1155,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce(
     [&](at::Tensor& input,
         at::Tensor& output,
         xcclComm_t& comm,
-        at::xpu::XPUStream& stream) {
+        at::xpu::XPUStream& stream,
+        ccl::stream& xcclStream) {
       const int root = opts.rootRank + opts.rootTensor;
       const auto xcclDataType = getXcclDataType(input.scalar_type(), true);
       const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-      auto xcclStream =
-          xcclStreamsMap_.at(std::to_string(tensor.device().index())).second;
       ccl::reduce(
           input.data_ptr(),
           output.data_ptr(),
@@ -1187,7 +1168,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce(
           xcclReduceOp,
           root,
           comm,
-          *xcclStream);
+          xcclStream);
       // WA due to oneCCL not support AVG
       if (opts.reduceOp == ReduceOp::AVG && getRank() == root) {
         auto divisor = getSize();
@@ -1216,13 +1197,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_oop(
     [&](at::Tensor& input,
         at::Tensor& output,
         xcclComm_t& comm,
-        at::xpu::XPUStream& stream) {
+        at::xpu::XPUStream& stream,
+        ccl::stream& xcclStream) {
       const int root = opts.rootRank + opts.rootTensor;
       const auto xcclDataType = getXcclDataType(input.scalar_type(), true);
       const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-      auto xcclStream =
-          xcclStreamsMap_.at(std::to_string(inputTensor[0].device().index()))
-              .second;
       ccl::reduce(
           input.data_ptr(),
           output.data_ptr(),
@@ -1231,7 +1210,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_oop(
           xcclReduceOp,
           root,
           comm,
-          *xcclStream);
+          xcclStream);
       // Use SUM emu AVG due to oneCCL not support AVG
       // oneCCL is expected to support avg in basekit 2025.2 release.
       if (opts.reduceOp == ReduceOp::AVG && getRank() == root) {
@@ -1287,20 +1266,18 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allgather(
     [&](at::Tensor& input,
         at::Tensor& output,
         xcclComm_t& comm,
-        at::xpu::XPUStream& stream) {
+        at::xpu::XPUStream& stream,
+        ccl::stream& xcclStream) {
       c10::xpu::XPUCachingAllocator::recordStream(
           output.storage().data_ptr(), stream);
       auto xcclDataType = getXcclDataType(input.scalar_type());
-      auto xcclStream =
-          xcclStreamsMap_.at(std::to_string(inputTensor.device().index()))
-              .second;
       ccl::allgather(
           input.data_ptr(),
           output.data_ptr(),
           (size_t)input.numel(),
           xcclDataType,
           comm,
-          *xcclStream);
+          xcclStream);
       return;
     },
     [](at::xpu::XPUStream&,
@@ -1371,20 +1348,18 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_allgather_base(
     [&](at::Tensor& input,
         at::Tensor& output,
         xcclComm_t& comm,
-        at::xpu::XPUStream& stream) {
+        at::xpu::XPUStream& stream,
+        ccl::stream& xcclStream) {
       c10::xpu::XPUCachingAllocator::recordStream(
           output.storage().data_ptr(), stream);
       auto xcclDataType = getXcclDataType(input.scalar_type());
-      auto xcclStream =
-          xcclStreamsMap_.at(std::to_string(input_tensor.device().index()))
-              .second;
       ccl::allgather(
           input.data_ptr(),
           output.data_ptr(),
           (size_t)input.numel(),
           xcclDataType,
           comm,
-          *xcclStream);
+          xcclStream);
       return;
     },
     OpType::_ALLGATHER_BASE,
@@ -1401,18 +1376,16 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allgather_into_tensor_coalesced(
     [&](at::Tensor& input,
         at::Tensor& output,
         xcclComm_t& comm,
-        at::xpu::XPUStream& stream) {
+        at::xpu::XPUStream& stream,
+        ccl::stream& xcclStream) {
       auto xcclDataType = getXcclDataType(input.scalar_type());
-      auto xcclStream =
-          xcclStreamsMap_.at(std::to_string(inputs[0].device().index()))
-              .second;
       ccl::allgather(
           input.data_ptr(),
           output.data_ptr(),
           (size_t)input.numel(),
           xcclDataType,
           comm,
-          *xcclStream);
+          xcclStream);
       return;
     },
     OpType::COALESCED,
@@ -1457,15 +1430,12 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter(
     [&](at::Tensor& input,
         at::Tensor& output,
         xcclComm_t& comm,
-        at::xpu::XPUStream& stream) {
+        at::xpu::XPUStream& stream,
+        ccl::stream& xcclStream) {
       c10::xpu::XPUCachingAllocator::recordStream(
           output.storage().data_ptr(), stream);
       auto xcclDataType = getXcclDataType(input.scalar_type(), true);
       auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-      auto xcclStream =
-          xcclStreamsMap_
-              .at(std::to_string(inputFlattened.device().index()))
-              .second;
       ccl::reduce_scatter(
           input.data_ptr(),
           output.data_ptr(),
@@ -1473,7 +1443,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter(
           xcclDataType,
           xcclReduceOp,
           comm,
-          *xcclStream);
+          xcclStream);
       // Use SUM emu AVG due to oneCCL not support AVG
       // oneCCL is expected to support avg in basekit 2025.2 release.
       if (opts.reduceOp == ReduceOp::AVG) {
@@ -1553,14 +1523,12 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_scatter_base(
     [&](at::Tensor& input,
         at::Tensor& output,
         xcclComm_t& comm,
-        at::xpu::XPUStream& stream) {
+        at::xpu::XPUStream& stream,
+        ccl::stream& xcclStream) {
       c10::xpu::XPUCachingAllocator::recordStream(
           output.storage().data_ptr(), stream);
       auto xcclDataType = getXcclDataType(input.scalar_type(), true);
       auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-      auto xcclStream =
-          xcclStreamsMap_.at(std::to_string(inputTensor.device().index()))
-              .second;
       ccl::reduce_scatter(
           input.data_ptr(),
           output.data_ptr(),
@@ -1568,7 +1536,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_scatter_base(
           xcclDataType,
           xcclReduceOp,
           comm,
-          *xcclStream);
+          xcclStream);
       // Use SUM emu AVG due to oneCCL not support AVG
       // oneCCL is expected to support avg in basekit 2025.2 release.
       if (opts.reduceOp == ReduceOp::AVG) {
@@ -1594,14 +1562,12 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter_tensor_coalesced(
     [&](at::Tensor& input,
         at::Tensor& output,
         xcclComm_t& comm,
-        at::xpu::XPUStream& stream) {
+        at::xpu::XPUStream& stream,
+        ccl::stream& xcclStream) {
       c10::xpu::XPUCachingAllocator::recordStream(
           output.storage().data_ptr(), stream);
       auto xcclDataType = getXcclDataType(input.scalar_type(), true);
       auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-      auto xcclStream =
-          xcclStreamsMap_.at(std::to_string(inputs[0].device().index()))
-              .second;
       ccl::reduce_scatter(
           input.data_ptr(),
           output.data_ptr(),
@@ -1609,7 +1575,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter_tensor_coalesced(
           xcclDataType,
           xcclReduceOp,
           comm,
-          *xcclStream);
+          xcclStream);
       // Use SUM emu AVG due to oneCCL not support AVG
       // oneCCL is expected to support avg in basekit 2025.2 release.
       if (opts.reduceOp == ReduceOp::AVG) {
@@ -1654,6 +1620,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::barrier(const BarrierOptions& opts) {
     barDevIdx =
         static_cast<int16_t>(rank_ % at::detail::getXPUHooks().getNumGPUs());
   }
+
   TORCH_CHECK_WITH(
       ValueError,
       barDevIdx >= 0,
@@ -1710,20 +1677,18 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall_base(
     [&](at::Tensor& input,
         at::Tensor& output,
         xcclComm_t& comm,
-        at::xpu::XPUStream& stream) {
+        at::xpu::XPUStream& stream,
+        ccl::stream& xcclStream) {
       c10::xpu::XPUCachingAllocator::recordStream(
           output.storage().data_ptr(), stream);
       auto xcclDataType = getXcclDataType(output.scalar_type());
-      auto xcclStream =
-          xcclStreamsMap_.at(std::to_string(inputTensor.device().index()))
-              .second;
       ccl::alltoall(
           input.data_ptr(),
           output.data_ptr(),
           (size_t)output.numel() / comm.size(),
           xcclDataType,
           comm,
-          *xcclStream);
+          xcclStream);
       return;
     },
     OpType::ALLTOALL_BASE,
@@ -1756,7 +1721,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall_base(
     [&](at::Tensor& input,
         at::Tensor& output,
         xcclComm_t& comm,
-        at::xpu::XPUStream& stream) {
+        at::xpu::XPUStream& stream,
+        ccl::stream& xcclStream) {
       std::vector<size_t> sendCounts(size_);
       std::vector<size_t> recvCounts(size_);
       bool inputSplitsEqual = inputSplitSizes.size() == 0;
@@ -1776,9 +1742,6 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall_base(
             (outputSplitsEqual ? outLen : outputSplitSizes[i] * outLen);
       }
       auto xcclDataType = getXcclDataType(output.scalar_type());
-      auto xcclStream =
-          xcclStreamsMap_.at(std::to_string(inputTensor.device().index()))
-              .second;
       ccl::alltoallv(
           input.data_ptr(),
           sendCounts,
@@ -1786,7 +1749,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall_base(
           recvCounts,
           xcclDataType,
           comm,
-          *xcclStream);
+          xcclStream);
       return;
     },
     OpType::ALLTOALL_BASE,
@@ -1833,7 +1796,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall(
     [&](at::Tensor& /* unused */,
         at::Tensor& /* unused */,
         xcclComm_t& comm,
-        at::xpu::XPUStream& stream) {
+        at::xpu::XPUStream& stream,
+        ccl::stream& xcclStream) {
       c10::OptionalStreamGuard stream_guard(stream.unwrap());
       at::Tensor flatInput;
       at::Tensor flatOutput;
@@ -1859,9 +1823,6 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall(
       }

       auto xcclDataType = getXcclDataType(flatOutput.scalar_type());
-      auto xcclStream =
-          xcclStreamsMap_.at(std::to_string(inputTensors[0].device().index()))
-              .second;
       ccl::event ret_evt;
       ret_evt = ccl::alltoallv(
           flatInput.data_ptr(),
@@ -1870,7 +1831,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall(
           recvCounts,
           xcclDataType,
           comm,
-          *xcclStream);
+          xcclStream);

       if (!isOutputFlat) {
         ret_evt.wait();

diff --git a/src/xccl/ProcessGroupXCCL.hpp b/src/xccl/ProcessGroupXCCL.hpp
index ecd9e422d..59a6cd655 100644
--- a/src/xccl/ProcessGroupXCCL.hpp
+++ b/src/xccl/ProcessGroupXCCL.hpp
@@ -323,9 +323,7 @@ class TORCH_API ProcessGroupXCCL : public Backend {
   uint64_t getSequenceNumberForGroup() override;

 protected:
-  std::unordered_map<
-      std::string,
-      std::pair<at::xpu::XPUStream, std::shared_ptr<ccl::stream>>>
+  std::unordered_map<std::string, std::pair<at::xpu::XPUStream, ccl::stream>>
       xcclStreamsMap_;
   std::unordered_map<std::string, at::xpu::XPUEvent> xcclEventsMap_;
   std::unordered_map<std::string, std::shared_ptr<xcclComm_t>> devXCCLCommMap_;
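A note on the final shape of patch 2: instead of each op deriving its own map
key, `collective` and `pointToPoint` now resolve both the XPU stream and the
`ccl::stream` from `xcclStreamsMap_` by the same `key` under which the
communicator was cached, and hand the latter to every op lambda as a new
`ccl::stream&` parameter. A minimal sketch of that pattern with stand-in
types (the real callbacks also take tensors, the communicator, and a rank):

    #include <functional>
    #include <iostream>
    #include <map>
    #include <string>
    #include <utility>

    struct XpuStream { int id; };  // stand-in for at::xpu::XPUStream
    struct CclStream { int id; };  // stand-in for ccl::stream

    std::map<std::string, std::pair<XpuStream, CclStream>> streamsMap;

    // Resolve both streams once, by the communicator's cache key, then hand
    // them to the per-op callback so it never re-derives a key of its own.
    void runOp(const std::string& key,
               const std::function<void(XpuStream&, CclStream&)>& fn) {
      auto& entry = streamsMap.at(key);
      fn(entry.first, entry.second);
    }

    int main() {
      streamsMap.emplace("0:1", std::make_pair(XpuStream{7}, CclStream{7}));
      runOp("0:1", [](XpuStream& s, CclStream& c) {
        std::cout << "xpu stream " << s.id << ", ccl stream " << c.id << "\n";
      });
      return 0;
    }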
From 40ac8f7ef8d864213dd83d8653b7e098f3e5bf46 Mon Sep 17 00:00:00 2001
From: "Han, Chao1"
Date: Mon, 10 Mar 2025 23:28:58 +0800
Subject: [PATCH 3/4] add test case

---
 test/xpu/distributed/test_c10d_xccl.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/test/xpu/distributed/test_c10d_xccl.py b/test/xpu/distributed/test_c10d_xccl.py
index dde51be68..eba34a3b5 100644
--- a/test/xpu/distributed/test_c10d_xccl.py
+++ b/test/xpu/distributed/test_c10d_xccl.py
@@ -494,6 +494,27 @@ def test_reduce_scatter_tensor_coalesced(self):
             dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i])
         self.assertEqual(output_tensors, input_tensors[self.rank] * self.world_size)

+    @requires_xccl()
+    @skip_if_lt_x_gpu(2)
+    # The difference between this case and `test_send_recv` is that `test_send_recv` uses a previously created process group,
+    # whereas this case performs point-to-point operations immediately after creating the process group.
+    def test_single_p2p(self):
+        store = dist.FileStore(self.file_name, self.world_size)
+        dist.init_process_group(
+            "xccl",
+            world_size=self.world_size,
+            rank=self.rank,
+            store=store,
+        )
+        torch.manual_seed(0)
+        send_tensor = torch.rand(10, 10).to(self.rank)
+        if self.rank == 0:
+            dist.send(send_tensor, 1)
+        if self.rank == 1:
+            recv_tensor = torch.rand(10, 10).to(self.rank)
+            dist.recv(recv_tensor, 0)
+            self.assertEqual(send_tensor, recv_tensor)
+

 class SetDeviceMethod(Enum):
     TORCH_XPU_SET = auto()  # torch.xpu.set_device

From 61a4cad852d2584c64813414141347bde15e70db Mon Sep 17 00:00:00 2001
From: "Han, Chao1"
Date: Tue, 11 Mar 2025 18:13:54 +0800
Subject: [PATCH 4/4] add comment

---
 src/xccl/ProcessGroupXCCL.cpp | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp
index 647f1a617..3d93beb86 100644
--- a/src/xccl/ProcessGroupXCCL.cpp
+++ b/src/xccl/ProcessGroupXCCL.cpp
@@ -352,6 +352,13 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
     ccl::group_start();
   }

+  // The oneCCL group API requires retaining the SYCL queue (xcclstream) object
+  // within the lifecycle of the communicator. If the XPU stream is created
+  // within the collective operation, it would be destroyed earlier than the
+  // communicator after the operation ends. Therefore, the XPU stream is stored
+  // in a map alongside the communicator. Similarly, oneCCLv2 also requires
+  // retaining the SYCL queue pointer for collective operations, so this change
+  // will be necessary in oneCCLv2 as well.
   ccl::stream xccl_stream = ccl::create_stream(q);
   std::lock_guard<std::mutex> lock(mutex_);
   devXCCLCommMap_.emplace(deviceKey, XCCLComm);
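Closing note: the comment added in patch 4 is a lifetime argument. A
`ccl::stream` created inside a collective would be destroyed as soon as the
call returns, while oneCCL expects the SYCL queue it wraps to stay alive for
as long as the communicator does, hence parking it in `xcclStreamsMap_` next
to the communicator entry. A toy C++ sketch of that argument with a stand-in
RAII type (not the oneCCL API):

    #include <iostream>
    #include <map>
    #include <string>
    #include <utility>

    // Stand-in for a queue/stream wrapper whose destruction we can observe.
    struct QueueHandle {
      std::string name;
      explicit QueueHandle(std::string n) : name(std::move(n)) {
        std::cout << "create  " << name << "\n";
      }
      ~QueueHandle() { std::cout << "destroy " << name << "\n"; }
    };

    // Lives as long as the "communicator" cache does.
    std::map<std::string, QueueHandle> cached;

    void collectiveOp() {
      QueueHandle local("per-op queue");  // gone as soon as the op returns
    }

    int main() {
      cached.try_emplace("0:1", "cached queue");  // constructed in place
      collectiveOp();
      std::cout << "communicator still in use; cached queue still alive\n";
      return 0;  // the cached queue is destroyed only with the map
    }

The `test_single_p2p` case from patch 3 pins the regression: it builds a
brand-new process group and immediately issues `dist.send`/`dist.recv`, which
is exactly the path that previously missed the cached stream entry.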