diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp
index 5d16cab07..3d93beb86 100644
--- a/src/xccl/ProcessGroupXCCL.cpp
+++ b/src/xccl/ProcessGroupXCCL.cpp
@@ -352,6 +352,13 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
     ccl::group_start();
   }
 
+  // The oneCCL group API requires retaining the SYCL queue (xcclstream) object
+  // within the lifecycle of the communicator. If the XPU stream is created
+  // within the collective operation, it would be destroyed earlier than the
+  // communicator after the operation ends. Therefore, the XPU stream is stored
+  // in a map alongside the communicator. Similarly, oneCCLv2 also requires
+  // retaining the SYCL queue pointer for collective operations, so this change
+  // will be necessary in oneCCLv2 as well.
   ccl::stream xccl_stream = ccl::create_stream(q);
   std::lock_guard<std::mutex> lock(mutex_);
   devXCCLCommMap_.emplace(deviceKey, XCCLComm);
@@ -450,6 +457,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
   }
 
   auto stream = xcclStreamsMap_.at(key).first;
+  auto cclstream = xcclStreamsMap_.at(key).second;
   syncStream(device, xcclEventsMap_[key], stream);
 
   c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
@@ -464,7 +472,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
   for (const auto i : c10::irange(inputs.size())) {
     c10::xpu::XPUCachingAllocator::recordStream(
         inputs[i].storage().data_ptr(), stream);
-    fn(inputs[i], outputs[i], *comm, stream);
+    fn(inputs[i], outputs[i], *comm, stream, cclstream);
   }
 
   post(stream, work);
@@ -531,6 +539,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
   }
 
   auto stream = xcclStreamsMap_.at(key).first;
+  auto cclstream = xcclStreamsMap_.at(key).second;
   syncStream(device, xcclEventsMap_[key], stream);
 
   if (!coalescing_state_) {
@@ -544,7 +553,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
     c10::xpu::XPUCachingAllocator::recordStream(
         tensor.storage().data_ptr(), stream);
 
-    fn(tensor, *comm, stream, p2pTargetRank);
+    fn(tensor, *comm, stream, cclstream, p2pTargetRank);
 
     work->xcclEndEvent_->record(stream);
     work->blockingWait_ = blockingWait_;
@@ -561,7 +570,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
     c10::xpu::XPUCachingAllocator::recordStream(
         tensor.storage().data_ptr(), stream);
 
-    fn(tensor, *comm, stream, p2pTargetRank);
+    fn(tensor, *comm, stream, cclstream, p2pTargetRank);
 
     return nullptr;
   }
@@ -598,10 +607,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::send(
       [&](at::Tensor& input,
           xcclComm_t& comm,
           at::xpu::XPUStream& stream,
+          ccl::stream& xcclStream,
           int dst) {
         auto xcclDataType = getXcclDataType(input.scalar_type());
-        auto xcclStream =
-            xcclStreamsMap_.at(std::to_string(tensor.device().index())).second;
         ccl::send(
             input.data_ptr(),
             (size_t)input.numel(),
@@ -648,10 +656,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::recv(
       [&](at::Tensor& output,
           xcclComm_t& comm,
           at::xpu::XPUStream& stream,
+          ccl::stream& xcclStream,
           int src) {
         auto xcclDataType = getXcclDataType(output.scalar_type());
-        auto xcclStream =
-            xcclStreamsMap_.at(std::to_string(tensor.device().index())).second;
         ccl::recv(
             output.data_ptr(),
             (size_t)output.numel(),
@@ -736,7 +743,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::gather(
       [&](at::Tensor& /* unused */,
           at::Tensor& /* unused */,
           xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
+          at::xpu::XPUStream& stream,
+          ccl::stream& xcclStream) {
         const auto root = opts.rootRank;
         if (getRank() == root) {
           for (auto output : outputs) {
@@ -746,9 +754,6 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::gather(
         }
         {
           auto xcclDataType = getXcclDataType(inputTensor.scalar_type());
-          auto xcclStream =
-              xcclStreamsMap_.at(std::to_string(inputs[0].device().index()))
-                  .second;
           if (rank_ == root) {
             for (const auto r : c10::irange(size_)) {
               if (r != root) {
@@ -852,7 +857,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::scatter(
       [&](at::Tensor& /* unused */,
          at::Tensor& /* unused */,
          xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
+          at::xpu::XPUStream& stream,
+          ccl::stream& xcclStream) {
         if (getRank() == root) {
           for (auto input : inputs) {
             c10::xpu::XPUCachingAllocator::recordStream(
@@ -860,9 +866,6 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::scatter(
           }
         }
         {
-          auto xcclStream =
-              xcclStreamsMap_.at(std::to_string(outputs[0].device().index()))
-                  .second;
           if (rank_ == root) {
             for (const auto r : c10::irange(size_)) {
               if (r != root) {
@@ -909,11 +912,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_impl(
       [&](at::Tensor& input,
           at::Tensor& output,
           xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
+          at::xpu::XPUStream& stream,
+          ccl::stream& xcclStream) {
         auto xcclDataType = getXcclDataType(input.scalar_type(), true);
         auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-        auto xcclStream =
-            xcclStreamsMap_.at(std::to_string(tensor.device().index())).second;
         ccl::allreduce(
             input.data_ptr(),
             output.data_ptr(),
@@ -965,11 +967,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
       [&](at::Tensor& input,
           at::Tensor& output,
           xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
+          at::xpu::XPUStream& stream,
+          ccl::stream& xcclStream) {
         auto xcclDataType = getXcclDataType(input.scalar_type(), true);
         auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-        auto xcclStream =
-            xcclStreamsMap_.at(std::to_string(tensor.device().index())).second;
         ccl::allreduce(
             input.data_ptr(),
             output.data_ptr(),
@@ -1022,12 +1023,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
       [&](at::Tensor& input,
           at::Tensor& output,
           xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
+          at::xpu::XPUStream& stream,
+          ccl::stream& xcclStream) {
         auto xcclDataType = getXcclDataType(input.scalar_type(), true);
         auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-        auto xcclStream =
-            xcclStreamsMap_.at(std::to_string(tensors[0].device().index()))
-                .second;
         ccl::allreduce(
             input.data_ptr(),
             output.data_ptr(),
@@ -1084,10 +1083,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::broadcast(
       [&](at::Tensor& input,
           at::Tensor& output,
           xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
+          at::xpu::XPUStream& stream,
+          ccl::stream& xcclStream) {
         auto xcclDataType = getXcclDataType(input.scalar_type());
-        auto xcclStream =
-            xcclStreamsMap_.at(std::to_string(tensor.device().index())).second;
         ccl::broadcast(
             input.data_ptr(),
             (size_t)input.numel(),
@@ -1117,11 +1115,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_broadcast_oop(
       [&](at::Tensor& input,
           at::Tensor& output,
           xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
+          at::xpu::XPUStream& stream,
+          ccl::stream& xcclStream) {
         auto xcclDataType = getXcclDataType(input.scalar_type());
-        auto xcclStream =
-            xcclStreamsMap_.at(std::to_string(inputTensor[0].device().index()))
-                .second;
         ccl::broadcast(
             input.data_ptr(),
             output.data_ptr(),
@@ -1166,12 +1162,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce(
       [&](at::Tensor& input,
           at::Tensor& output,
           xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
+          at::xpu::XPUStream& stream,
+          ccl::stream& xcclStream) {
         const int root = opts.rootRank + opts.rootTensor;
         const auto xcclDataType = getXcclDataType(input.scalar_type(), true);
         const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-        auto xcclStream =
-            xcclStreamsMap_.at(std::to_string(tensor.device().index())).second;
         ccl::reduce(
             input.data_ptr(),
             output.data_ptr(),
@@ -1209,13 +1204,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_oop(
       [&](at::Tensor& input,
           at::Tensor& output,
           xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
+          at::xpu::XPUStream& stream,
+          ccl::stream& xcclStream) {
         const int root = opts.rootRank + opts.rootTensor;
         const auto xcclDataType = getXcclDataType(input.scalar_type(), true);
         const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-        auto xcclStream =
-            xcclStreamsMap_.at(std::to_string(inputTensor[0].device().index()))
-                .second;
         ccl::reduce(
             input.data_ptr(),
             output.data_ptr(),
@@ -1280,13 +1273,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allgather(
       [&](at::Tensor& input,
           at::Tensor& output,
           xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
+          at::xpu::XPUStream& stream,
+          ccl::stream& xcclStream) {
         c10::xpu::XPUCachingAllocator::recordStream(
             output.storage().data_ptr(), stream);
         auto xcclDataType = getXcclDataType(input.scalar_type());
-        auto xcclStream =
-            xcclStreamsMap_.at(std::to_string(inputTensor.device().index()))
-                .second;
         ccl::allgather(
             input.data_ptr(),
             output.data_ptr(),
@@ -1364,13 +1355,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_allgather_base(
       [&](at::Tensor& input,
           at::Tensor& output,
           xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
+          at::xpu::XPUStream& stream,
+          ccl::stream& xcclStream) {
         c10::xpu::XPUCachingAllocator::recordStream(
             output.storage().data_ptr(), stream);
         auto xcclDataType = getXcclDataType(input.scalar_type());
-        auto xcclStream =
-            xcclStreamsMap_.at(std::to_string(input_tensor.device().index()))
-                .second;
         ccl::allgather(
             input.data_ptr(),
             output.data_ptr(),
@@ -1394,11 +1383,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allgather_into_tensor_coalesced(
       [&](at::Tensor& input,
           at::Tensor& output,
           xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
+          at::xpu::XPUStream& stream,
+          ccl::stream& xcclStream) {
         auto xcclDataType = getXcclDataType(input.scalar_type());
-        auto xcclStream =
-            xcclStreamsMap_.at(std::to_string(inputs[0].device().index()))
-                .second;
         ccl::allgather(
             input.data_ptr(),
             output.data_ptr(),
@@ -1450,15 +1437,12 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter(
       [&](at::Tensor& input,
           at::Tensor& output,
           xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
+          at::xpu::XPUStream& stream,
+          ccl::stream& xcclStream) {
         c10::xpu::XPUCachingAllocator::recordStream(
             output.storage().data_ptr(), stream);
         auto xcclDataType = getXcclDataType(input.scalar_type(), true);
         auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-        auto xcclStream =
-            xcclStreamsMap_
-                .at(std::to_string(inputFlattened.device().index()))
-                .second;
         ccl::reduce_scatter(
             input.data_ptr(),
             output.data_ptr(),
@@ -1546,14 +1530,12 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_scatter_base(
       [&](at::Tensor& input,
           at::Tensor& output,
           xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
+          at::xpu::XPUStream& stream,
+          ccl::stream& xcclStream) {
         c10::xpu::XPUCachingAllocator::recordStream(
             output.storage().data_ptr(), stream);
         auto xcclDataType = getXcclDataType(input.scalar_type(), true);
         auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-        auto xcclStream =
-            xcclStreamsMap_.at(std::to_string(inputTensor.device().index()))
-                .second;
         ccl::reduce_scatter(
             input.data_ptr(),
             output.data_ptr(),
@@ -1587,14 +1569,12 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter_tensor_coalesced(
       [&](at::Tensor& input,
           at::Tensor& output,
           xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
+          at::xpu::XPUStream& stream,
+          ccl::stream& xcclStream) {
         c10::xpu::XPUCachingAllocator::recordStream(
             output.storage().data_ptr(), stream);
         auto xcclDataType = getXcclDataType(input.scalar_type(), true);
         auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-        auto xcclStream =
-            xcclStreamsMap_.at(std::to_string(inputs[0].device().index()))
-                .second;
         ccl::reduce_scatter(
             input.data_ptr(),
             output.data_ptr(),
@@ -1704,13 +1684,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall_base(
       [&](at::Tensor& input,
           at::Tensor& output,
           xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
+          at::xpu::XPUStream& stream,
+          ccl::stream& xcclStream) {
         c10::xpu::XPUCachingAllocator::recordStream(
             output.storage().data_ptr(), stream);
         auto xcclDataType = getXcclDataType(output.scalar_type());
-        auto xcclStream =
-            xcclStreamsMap_.at(std::to_string(inputTensor.device().index()))
-                .second;
         ccl::alltoall(
             input.data_ptr(),
             output.data_ptr(),
@@ -1750,7 +1728,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall_base(
       [&](at::Tensor& input,
           at::Tensor& output,
           xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
+          at::xpu::XPUStream& stream,
+          ccl::stream& xcclStream) {
         std::vector<size_t> sendCounts(size_);
         std::vector<size_t> recvCounts(size_);
         bool inputSplitsEqual = inputSplitSizes.size() == 0;
@@ -1770,9 +1749,6 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall_base(
               (outputSplitsEqual ? outLen : outputSplitSizes[i] * outLen);
         }
         auto xcclDataType = getXcclDataType(output.scalar_type());
-        auto xcclStream =
-            xcclStreamsMap_.at(std::to_string(inputTensor.device().index()))
-                .second;
         ccl::alltoallv(
             input.data_ptr(),
             sendCounts,
@@ -1827,7 +1803,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall(
       [&](at::Tensor& /* unused */,
           at::Tensor& /* unused */,
           xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
+          at::xpu::XPUStream& stream,
+          ccl::stream& xcclStream) {
         c10::OptionalStreamGuard stream_guard(stream.unwrap());
         at::Tensor flatInput;
         at::Tensor flatOutput;
@@ -1853,9 +1830,6 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall(
         }
 
         auto xcclDataType = getXcclDataType(flatOutput.scalar_type());
-        auto xcclStream =
-            xcclStreamsMap_.at(std::to_string(inputTensors[0].device().index()))
-                .second;
         ccl::event ret_evt;
         ret_evt = ccl::alltoallv(
             flatInput.data_ptr(),
diff --git a/test/xpu/distributed/test_c10d_xccl.py b/test/xpu/distributed/test_c10d_xccl.py
index dde51be68..eba34a3b5 100644
--- a/test/xpu/distributed/test_c10d_xccl.py
+++ b/test/xpu/distributed/test_c10d_xccl.py
@@ -494,6 +494,27 @@ def test_reduce_scatter_tensor_coalesced(self):
             dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i])
         self.assertEqual(output_tensors, input_tensors[self.rank] * self.world_size)
 
+    @requires_xccl()
+    @skip_if_lt_x_gpu(2)
+    # The difference between this case and `test_send_recv` is that `test_send_recv` uses a previously created process group,
+    # whereas this case performs point-to-point operations immediately after creating the process group.
+    def test_single_p2p(self):
+        store = dist.FileStore(self.file_name, self.world_size)
+        dist.init_process_group(
+            "xccl",
+            world_size=self.world_size,
+            rank=self.rank,
+            store=store,
+        )
+        torch.manual_seed(0)
+        send_tensor = torch.rand(10, 10).to(self.rank)
+        if self.rank == 0:
+            dist.send(send_tensor, 1)
+        if self.rank == 1:
+            recv_tensor = torch.rand(10, 10).to(self.rank)
+            dist.recv(recv_tensor, 0)
+            self.assertEqual(send_tensor, recv_tensor)
+
 
 class SetDeviceMethod(Enum):
     TORCH_XPU_SET = auto()  # torch.xpu.set_device
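
A note for reviewers on the pattern behind this patch: the comment added in `getXCCLComm` explains that the `ccl::stream` built from the SYCL queue has to outlive each collective call, so it is created once and cached in `xcclStreamsMap_` next to the XPU stream, and the collective/point-to-point lambdas now receive that cached `cclstream` as an extra parameter instead of looking it up or recreating it per call. Below is a minimal stand-alone sketch of that caching pattern only; `FakeXpuStream`, `FakeCclStream`, `FakeProcessGroup`, `cacheStreams`, and `collective` are illustrative stand-ins, not the real oneCCL/SYCL or `ProcessGroupXCCL` types.

```cpp
// Minimal model of the stream-caching pattern from this diff (plain C++ only;
// FakeXpuStream / FakeCclStream stand in for at::xpu::XPUStream / ccl::stream).
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <utility>

struct FakeXpuStream { int device; };
struct FakeCclStream { int device; };  // must outlive the communicator

class FakeProcessGroup {
 public:
  // Mirrors getXCCLComm(): both streams are created once, alongside the
  // communicator, and cached in a per-device map.
  void cacheStreams(const std::string& deviceKey, int device) {
    streamsMap_.emplace(
        deviceKey, std::make_pair(FakeXpuStream{device}, FakeCclStream{device}));
  }

  // Mirrors collective()/pointToPoint(): the cached pair is looked up by key
  // and both pieces are handed to the per-tensor functor, so no short-lived
  // ccl stream is created inside the operation.
  void collective(
      const std::string& deviceKey,
      const std::function<void(FakeXpuStream&, FakeCclStream&)>& fn) {
    auto& streams = streamsMap_.at(deviceKey);
    fn(streams.first, streams.second);  // cclstream lives in the map, not on this stack frame
  }

 private:
  std::map<std::string, std::pair<FakeXpuStream, FakeCclStream>> streamsMap_;
};

int main() {
  FakeProcessGroup pg;
  pg.cacheStreams("0", /*device=*/0);
  pg.collective("0", [](FakeXpuStream& stream, FakeCclStream& cclstream) {
    std::cout << "collective on XPU stream for device " << stream.device
              << ", reusing cached ccl stream for device " << cclstream.device
              << std::endl;
  });
  return 0;
}
```

The design choice mirrored here is that the stream pair's lifetime is tied to the map entry, and therefore to the communicator, rather than to the stack frame of a single operation; that is the situation the new `test_single_p2p` case exercises by issuing send/recv immediately after process-group creation.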