Skip to content

Commit f07f6e5

Browse files
committed
add version select
1 parent ea9a745 commit f07f6e5

File tree

1 file changed

+89
-4
lines changed

1 file changed

+89
-4
lines changed

src/xccl/ProcessGroupXCCL.cpp

+89-4
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,21 @@
77
namespace c10d {
88

99
namespace {
10+
11+
#if defined(CCL_MAJOR_VERSION) && \
12+
((CCL_MAJOR_VERSION > 2021) || \
13+
(CCL_MAJOR_VERSION == 2021) && (CCL_MINOR_VERSION >= 15))
14+
#define XCCL_HAS_AVG 1
15+
#endif // oneCCL version >= 2021.15
16+
1017
const std::map<c10d::ReduceOp, ccl::reduction> xcclOps = {
1118
{ReduceOp::MIN, ccl::reduction::min},
1219
{ReduceOp::MAX, ccl::reduction::max},
1320
{ReduceOp::SUM, ccl::reduction::sum},
14-
{ReduceOp::AVG, ccl::reduction::avg},
1521
{ReduceOp::PRODUCT, ccl::reduction::prod},
22+
#ifdef XCCL_HAS_AVG
23+
{ReduceOp::AVG, ccl::reduction::avg},
24+
#endif // XCCL_HAS_AVG
1625
};
1726

1827
const std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
@@ -144,10 +153,23 @@ ccl::datatype getXcclDataType(
144153

145154
ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
146155
try {
147-
if (input.scalar_type() == at::kBool && reduceOp == ReduceOp::SUM) {
148-
// Map sum to max for bool tensors to avoid overflow issues with sum.
149-
return ccl::reduction::max;
156+
if (input.scalar_type() == at::kBool) {
157+
if (reduceOp == ReduceOp::SUM) {
158+
// Map sum to max for bool tensors to avoid overflow issues with sum.
159+
return ccl::reduction::max;
160+
}
161+
#ifdef XCCL_HAS_AVG
162+
if (reduceOp == ReduceOp::AVG) {
163+
C10_THROW_ERROR(
164+
TypeError, "Cannot use ReduceOp.AVG with boolean inputs");
165+
}
166+
#endif // XCCL_HAS_AVG
167+
}
168+
#if !defined(XCCL_HAS_AVG)
169+
if (reduceOp == ReduceOp::AVG) {
170+
return ccl::reduction::sum;
150171
}
172+
#endif
151173
return xcclOps.at(reduceOp);
152174
} catch (const std::out_of_range&) {
153175
C10_THROW_ERROR(
@@ -894,6 +916,15 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_impl(
894916
xcclReduceOp,
895917
comm,
896918
ccl::create_stream(stream.queue()));
919+
#if !defined(XCCL_HAS_AVG)
920+
if (opts.reduceOp == ReduceOp::AVG) {
921+
auto divisor = getSize();
922+
c10::StreamGuard guard(stream);
923+
c10::xpu::XPUCachingAllocator::recordStream(
924+
output.storage().data_ptr(), stream);
925+
output.div_(divisor);
926+
}
927+
#endif
897928
return;
898929
},
899930
OpType::ALLREDUCE,
@@ -988,6 +1019,15 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
9881019
xcclReduceOp,
9891020
comm,
9901021
ccl::create_stream(stream.queue()));
1022+
#if !defined(XCCL_HAS_AVG)
1023+
if (opts.reduceOp == ReduceOp::AVG) {
1024+
auto divisor = getSize();
1025+
c10::StreamGuard guard(stream);
1026+
c10::xpu::XPUCachingAllocator::recordStream(
1027+
output.storage().data_ptr(), stream);
1028+
output.div_(divisor);
1029+
}
1030+
#endif
9911031
return;
9921032
},
9931033
OpType::COALESCED,
@@ -1117,6 +1157,15 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce(
11171157
root,
11181158
comm,
11191159
ccl::create_stream(stream.queue()));
1160+
#if !defined(XCCL_HAS_AVG)
1161+
if (opts.reduceOp == ReduceOp::AVG && getRank() == root) {
1162+
auto divisor = getSize();
1163+
c10::StreamGuard guard(stream);
1164+
c10::xpu::XPUCachingAllocator::recordStream(
1165+
output.storage().data_ptr(), stream);
1166+
output.div_(divisor);
1167+
}
1168+
#endif
11201169
return;
11211170
},
11221171
OpType::REDUCE,
@@ -1150,6 +1199,15 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_oop(
11501199
root,
11511200
comm,
11521201
ccl::create_stream(stream.queue()));
1202+
#if !defined(XCCL_HAS_AVG)
1203+
if (opts.reduceOp == ReduceOp::AVG && getRank() == root) {
1204+
auto divisor = getSize();
1205+
c10::StreamGuard guard(stream);
1206+
c10::xpu::XPUCachingAllocator::recordStream(
1207+
output.storage().data_ptr(), stream);
1208+
output.div_(divisor);
1209+
}
1210+
#endif
11531211
return;
11541212
},
11551213
OpType::REDUCE,
@@ -1370,6 +1428,15 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter(
13701428
xcclReduceOp,
13711429
comm,
13721430
ccl::create_stream(stream.queue()));
1431+
#if !defined(XCCL_HAS_AVG)
1432+
if (opts.reduceOp == ReduceOp::AVG) {
1433+
auto divisor = getSize();
1434+
c10::StreamGuard guard(stream);
1435+
c10::xpu::XPUCachingAllocator::recordStream(
1436+
output.storage().data_ptr(), stream);
1437+
output.div_(divisor);
1438+
}
1439+
#endif
13731440
return;
13741441
},
13751442
[&](at::xpu::XPUStream& Stream,
@@ -1453,6 +1520,15 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_scatter_base(
14531520
xcclReduceOp,
14541521
comm,
14551522
ccl::create_stream(stream.queue()));
1523+
#if !defined(XCCL_HAS_AVG)
1524+
if (opts.reduceOp == ReduceOp::AVG) {
1525+
auto divisor = getSize();
1526+
c10::StreamGuard guard(stream);
1527+
c10::xpu::XPUCachingAllocator::recordStream(
1528+
output.storage().data_ptr(), stream);
1529+
output.div_(divisor);
1530+
}
1531+
#endif
14561532
return;
14571533
},
14581534
OpType::_REDUCE_SCATTER_BASE,
@@ -1482,6 +1558,15 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter_tensor_coalesced(
14821558
xcclReduceOp,
14831559
comm,
14841560
ccl::create_stream(stream.queue()));
1561+
#if !defined(XCCL_HAS_AVG)
1562+
if (opts.reduceOp == ReduceOp::AVG) {
1563+
auto divisor = getSize();
1564+
c10::StreamGuard guard(stream);
1565+
c10::xpu::XPUCachingAllocator::recordStream(
1566+
output.storage().data_ptr(), stream);
1567+
output.div_(divisor);
1568+
}
1569+
#endif
14851570
return;
14861571
},
14871572
OpType::COALESCED,

0 commit comments

Comments
 (0)