Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix singlep2p cclstream mapping #1445

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 43 additions & 76 deletions src/xccl/ProcessGroupXCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
}

auto stream = xcclStreamsMap_.at(key).first;
auto cclstream = xcclStreamsMap_.at(key).second;
syncStream(device, xcclEventsMap_[key], stream);

c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
Expand All @@ -464,7 +465,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
for (const auto i : c10::irange(inputs.size())) {
c10::xpu::XPUCachingAllocator::recordStream(
inputs[i].storage().data_ptr(), stream);
fn(inputs[i], outputs[i], *comm, stream);
fn(inputs[i], outputs[i], *comm, stream, cclstream);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add some comments to explain the background. In the future, we will still need to change this in the new C API.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a comment at line 355.

}

post(stream, work);
Expand Down Expand Up @@ -531,6 +532,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
}

auto stream = xcclStreamsMap_.at(key).first;
auto cclstream = xcclStreamsMap_.at(key).second;
syncStream(device, xcclEventsMap_[key], stream);

if (!coalescing_state_) {
Expand All @@ -544,7 +546,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
c10::xpu::XPUCachingAllocator::recordStream(
tensor.storage().data_ptr(), stream);

fn(tensor, *comm, stream, p2pTargetRank);
fn(tensor, *comm, stream, cclstream, p2pTargetRank);

work->xcclEndEvent_->record(stream);
work->blockingWait_ = blockingWait_;
Expand All @@ -561,7 +563,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
c10::xpu::XPUCachingAllocator::recordStream(
tensor.storage().data_ptr(), stream);

fn(tensor, *comm, stream, p2pTargetRank);
fn(tensor, *comm, stream, cclstream, p2pTargetRank);

return nullptr;
}
Expand Down Expand Up @@ -598,10 +600,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::send(
[&](at::Tensor& input,
xcclComm_t& comm,
at::xpu::XPUStream& stream,
ccl::stream& xcclStream,
int dst) {
auto xcclDataType = getXcclDataType(input.scalar_type());
auto xcclStream =
xcclStreamsMap_.at(std::to_string(tensor.device().index())).second;
ccl::send(
input.data_ptr(),
(size_t)input.numel(),
Expand Down Expand Up @@ -648,10 +649,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::recv(
[&](at::Tensor& output,
xcclComm_t& comm,
at::xpu::XPUStream& stream,
ccl::stream& xcclStream,
int src) {
auto xcclDataType = getXcclDataType(output.scalar_type());
auto xcclStream =
xcclStreamsMap_.at(std::to_string(tensor.device().index())).second;
ccl::recv(
output.data_ptr(),
(size_t)output.numel(),
Expand Down Expand Up @@ -736,7 +736,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::gather(
[&](at::Tensor& /* unused */,
at::Tensor& /* unused */,
xcclComm_t& comm,
at::xpu::XPUStream& stream) {
at::xpu::XPUStream& stream,
ccl::stream& xcclStream) {
const auto root = opts.rootRank;
if (getRank() == root) {
for (auto output : outputs) {
Expand All @@ -746,9 +747,6 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::gather(
}
{
auto xcclDataType = getXcclDataType(inputTensor.scalar_type());
auto xcclStream =
xcclStreamsMap_.at(std::to_string(inputs[0].device().index()))
.second;
if (rank_ == root) {
for (const auto r : c10::irange(size_)) {
if (r != root) {
Expand Down Expand Up @@ -852,17 +850,15 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::scatter(
[&](at::Tensor& /* unused */,
at::Tensor& /* unused */,
xcclComm_t& comm,
at::xpu::XPUStream& stream) {
at::xpu::XPUStream& stream,
ccl::stream& xcclStream) {
if (getRank() == root) {
for (auto input : inputs) {
c10::xpu::XPUCachingAllocator::recordStream(
input.storage().data_ptr(), stream);
}
}
{
auto xcclStream =
xcclStreamsMap_.at(std::to_string(outputs[0].device().index()))
.second;
if (rank_ == root) {
for (const auto r : c10::irange(size_)) {
if (r != root) {
Expand Down Expand Up @@ -909,11 +905,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_impl(
[&](at::Tensor& input,
at::Tensor& output,
xcclComm_t& comm,
at::xpu::XPUStream& stream) {
at::xpu::XPUStream& stream,
ccl::stream& xcclStream) {
auto xcclDataType = getXcclDataType(input.scalar_type(), true);
auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
auto xcclStream =
xcclStreamsMap_.at(std::to_string(tensor.device().index())).second;
ccl::allreduce(
input.data_ptr(),
output.data_ptr(),
Expand Down Expand Up @@ -965,11 +960,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
[&](at::Tensor& input,
at::Tensor& output,
xcclComm_t& comm,
at::xpu::XPUStream& stream) {
at::xpu::XPUStream& stream,
ccl::stream& xcclStream) {
auto xcclDataType = getXcclDataType(input.scalar_type(), true);
auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
auto xcclStream =
xcclStreamsMap_.at(std::to_string(tensor.device().index())).second;
ccl::allreduce(
input.data_ptr(),
output.data_ptr(),
Expand Down Expand Up @@ -1022,12 +1016,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
[&](at::Tensor& input,
at::Tensor& output,
xcclComm_t& comm,
at::xpu::XPUStream& stream) {
at::xpu::XPUStream& stream,
ccl::stream& xcclStream) {
auto xcclDataType = getXcclDataType(input.scalar_type(), true);
auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
auto xcclStream =
xcclStreamsMap_.at(std::to_string(tensors[0].device().index()))
.second;
ccl::allreduce(
input.data_ptr(),
output.data_ptr(),
Expand Down Expand Up @@ -1084,10 +1076,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::broadcast(
[&](at::Tensor& input,
at::Tensor& output,
xcclComm_t& comm,
at::xpu::XPUStream& stream) {
at::xpu::XPUStream& stream,
ccl::stream& xcclStream) {
auto xcclDataType = getXcclDataType(input.scalar_type());
auto xcclStream =
xcclStreamsMap_.at(std::to_string(tensor.device().index())).second;
ccl::broadcast(
input.data_ptr(),
(size_t)input.numel(),
Expand Down Expand Up @@ -1117,11 +1108,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_broadcast_oop(
[&](at::Tensor& input,
at::Tensor& output,
xcclComm_t& comm,
at::xpu::XPUStream& stream) {
at::xpu::XPUStream& stream,
ccl::stream& xcclStream) {
auto xcclDataType = getXcclDataType(input.scalar_type());
auto xcclStream =
xcclStreamsMap_.at(std::to_string(inputTensor[0].device().index()))
.second;
ccl::broadcast(
input.data_ptr(),
output.data_ptr(),
Expand Down Expand Up @@ -1166,12 +1155,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce(
[&](at::Tensor& input,
at::Tensor& output,
xcclComm_t& comm,
at::xpu::XPUStream& stream) {
at::xpu::XPUStream& stream,
ccl::stream& xcclStream) {
const int root = opts.rootRank + opts.rootTensor;
const auto xcclDataType = getXcclDataType(input.scalar_type(), true);
const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
auto xcclStream =
xcclStreamsMap_.at(std::to_string(tensor.device().index())).second;
ccl::reduce(
input.data_ptr(),
output.data_ptr(),
Expand Down Expand Up @@ -1209,13 +1197,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_oop(
[&](at::Tensor& input,
at::Tensor& output,
xcclComm_t& comm,
at::xpu::XPUStream& stream) {
at::xpu::XPUStream& stream,
ccl::stream& xcclStream) {
const int root = opts.rootRank + opts.rootTensor;
const auto xcclDataType = getXcclDataType(input.scalar_type(), true);
const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
auto xcclStream =
xcclStreamsMap_.at(std::to_string(inputTensor[0].device().index()))
.second;
ccl::reduce(
input.data_ptr(),
output.data_ptr(),
Expand Down Expand Up @@ -1280,13 +1266,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allgather(
[&](at::Tensor& input,
at::Tensor& output,
xcclComm_t& comm,
at::xpu::XPUStream& stream) {
at::xpu::XPUStream& stream,
ccl::stream& xcclStream) {
c10::xpu::XPUCachingAllocator::recordStream(
output.storage().data_ptr(), stream);
auto xcclDataType = getXcclDataType(input.scalar_type());
auto xcclStream =
xcclStreamsMap_.at(std::to_string(inputTensor.device().index()))
.second;
ccl::allgather(
input.data_ptr(),
output.data_ptr(),
Expand Down Expand Up @@ -1364,13 +1348,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_allgather_base(
[&](at::Tensor& input,
at::Tensor& output,
xcclComm_t& comm,
at::xpu::XPUStream& stream) {
at::xpu::XPUStream& stream,
ccl::stream& xcclStream) {
c10::xpu::XPUCachingAllocator::recordStream(
output.storage().data_ptr(), stream);
auto xcclDataType = getXcclDataType(input.scalar_type());
auto xcclStream =
xcclStreamsMap_.at(std::to_string(input_tensor.device().index()))
.second;
ccl::allgather(
input.data_ptr(),
output.data_ptr(),
Expand All @@ -1394,11 +1376,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allgather_into_tensor_coalesced(
[&](at::Tensor& input,
at::Tensor& output,
xcclComm_t& comm,
at::xpu::XPUStream& stream) {
at::xpu::XPUStream& stream,
ccl::stream& xcclStream) {
auto xcclDataType = getXcclDataType(input.scalar_type());
auto xcclStream =
xcclStreamsMap_.at(std::to_string(inputs[0].device().index()))
.second;
ccl::allgather(
input.data_ptr(),
output.data_ptr(),
Expand Down Expand Up @@ -1450,15 +1430,12 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter(
[&](at::Tensor& input,
at::Tensor& output,
xcclComm_t& comm,
at::xpu::XPUStream& stream) {
at::xpu::XPUStream& stream,
ccl::stream& xcclStream) {
c10::xpu::XPUCachingAllocator::recordStream(
output.storage().data_ptr(), stream);
auto xcclDataType = getXcclDataType(input.scalar_type(), true);
auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
auto xcclStream =
xcclStreamsMap_
.at(std::to_string(inputFlattened.device().index()))
.second;
ccl::reduce_scatter(
input.data_ptr(),
output.data_ptr(),
Expand Down Expand Up @@ -1546,14 +1523,12 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_scatter_base(
[&](at::Tensor& input,
at::Tensor& output,
xcclComm_t& comm,
at::xpu::XPUStream& stream) {
at::xpu::XPUStream& stream,
ccl::stream& xcclStream) {
c10::xpu::XPUCachingAllocator::recordStream(
output.storage().data_ptr(), stream);
auto xcclDataType = getXcclDataType(input.scalar_type(), true);
auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
auto xcclStream =
xcclStreamsMap_.at(std::to_string(inputTensor.device().index()))
.second;
ccl::reduce_scatter(
input.data_ptr(),
output.data_ptr(),
Expand Down Expand Up @@ -1587,14 +1562,12 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter_tensor_coalesced(
[&](at::Tensor& input,
at::Tensor& output,
xcclComm_t& comm,
at::xpu::XPUStream& stream) {
at::xpu::XPUStream& stream,
ccl::stream& xcclStream) {
c10::xpu::XPUCachingAllocator::recordStream(
output.storage().data_ptr(), stream);
auto xcclDataType = getXcclDataType(input.scalar_type(), true);
auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
auto xcclStream =
xcclStreamsMap_.at(std::to_string(inputs[0].device().index()))
.second;
ccl::reduce_scatter(
input.data_ptr(),
output.data_ptr(),
Expand Down Expand Up @@ -1704,13 +1677,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall_base(
[&](at::Tensor& input,
at::Tensor& output,
xcclComm_t& comm,
at::xpu::XPUStream& stream) {
at::xpu::XPUStream& stream,
ccl::stream& xcclStream) {
c10::xpu::XPUCachingAllocator::recordStream(
output.storage().data_ptr(), stream);
auto xcclDataType = getXcclDataType(output.scalar_type());
auto xcclStream =
xcclStreamsMap_.at(std::to_string(inputTensor.device().index()))
.second;
ccl::alltoall(
input.data_ptr(),
output.data_ptr(),
Expand Down Expand Up @@ -1750,7 +1721,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall_base(
[&](at::Tensor& input,
at::Tensor& output,
xcclComm_t& comm,
at::xpu::XPUStream& stream) {
at::xpu::XPUStream& stream,
ccl::stream& xcclStream) {
std::vector<size_t> sendCounts(size_);
std::vector<size_t> recvCounts(size_);
bool inputSplitsEqual = inputSplitSizes.size() == 0;
Expand All @@ -1770,9 +1742,6 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall_base(
(outputSplitsEqual ? outLen : outputSplitSizes[i] * outLen);
}
auto xcclDataType = getXcclDataType(output.scalar_type());
auto xcclStream =
xcclStreamsMap_.at(std::to_string(inputTensor.device().index()))
.second;
ccl::alltoallv(
input.data_ptr(),
sendCounts,
Expand Down Expand Up @@ -1827,7 +1796,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall(
[&](at::Tensor& /* unused */,
at::Tensor& /* unused */,
xcclComm_t& comm,
at::xpu::XPUStream& stream) {
at::xpu::XPUStream& stream,
ccl::stream& xcclStream) {
c10::OptionalStreamGuard stream_guard(stream.unwrap());
at::Tensor flatInput;
at::Tensor flatOutput;
Expand All @@ -1853,9 +1823,6 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall(
}

auto xcclDataType = getXcclDataType(flatOutput.scalar_type());
auto xcclStream =
xcclStreamsMap_.at(std::to_string(inputTensors[0].device().index()))
.second;
ccl::event ret_evt;
ret_evt = ccl::alltoallv(
flatInput.data_ptr(),
Expand Down
21 changes: 21 additions & 0 deletions test/xpu/distributed/test_c10d_xccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,27 @@ def test_reduce_scatter_tensor_coalesced(self):
dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i])
self.assertEqual(output_tensors, input_tensors[self.rank] * self.world_size)

@requires_xccl()
@skip_if_lt_x_gpu(2)
# The difference between this case and `test_send_recv` is that `test_send_recv` uses a previously created process group,
# whereas this case performs point-to-point operations immediately after creating the process group.
def test_single_p2p(self):
    """Regression test: send/recv must succeed as the very first operation on a
    freshly created process group (presumably exercising the lazily created
    ccl-stream mapping this PR fixes — confirm against ProcessGroupXCCL::pointToPoint).
    """
    store = dist.FileStore(self.file_name, self.world_size)
    dist.init_process_group(
        "xccl",
        world_size=self.world_size,
        rank=self.rank,
        store=store,
    )
    # The same seed on every rank lets rank 1 reconstruct the exact tensor
    # rank 0 sends, without any additional communication.
    torch.manual_seed(0)
    send_tensor = torch.rand(10, 10).to(self.rank)
    if self.rank == 0:
        dist.send(send_tensor, 1)
    if self.rank == 1:
        # Zero-filled receive buffer: a deterministic sentinel that cannot
        # accidentally equal the expected tensor, so a silently failed recv
        # is guaranteed to trip the assertion below.
        recv_tensor = torch.zeros(10, 10).to(self.rank)
        dist.recv(recv_tensor, 0)
        self.assertEqual(send_tensor, recv_tensor)


class SetDeviceMethod(Enum):
TORCH_XPU_SET = auto() # torch.xpu.set_device
Expand Down
Loading