|
7 | 7 | namespace c10d {
|
8 | 8 |
|
9 | 9 | namespace {
|
| 10 | + |
| 11 | +#if defined(CCL_MAJOR_VERSION) && \ |
| 12 | + ((CCL_MAJOR_VERSION > 2021) || \ |
| 13 | + (CCL_MAJOR_VERSION == 2021) && (CCL_MINOR_VERSION >= 15)) |
| 14 | +#define XCCL_HAS_AVG 1 |
| 15 | +#endif // oneCCL version >= 2021.15 |
| 16 | + |
10 | 17 | const std::map<c10d::ReduceOp, ccl::reduction> xcclOps = {
|
11 | 18 | {ReduceOp::MIN, ccl::reduction::min},
|
12 | 19 | {ReduceOp::MAX, ccl::reduction::max},
|
13 | 20 | {ReduceOp::SUM, ccl::reduction::sum},
|
14 |
| - {ReduceOp::AVG, ccl::reduction::avg}, |
15 | 21 | {ReduceOp::PRODUCT, ccl::reduction::prod},
|
| 22 | +#ifdef XCCL_HAS_AVG |
| 23 | + {ReduceOp::AVG, ccl::reduction::avg}, |
| 24 | +#endif // XCCL_HAS_AVG |
16 | 25 | };
|
17 | 26 |
|
18 | 27 | const std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
|
@@ -144,10 +153,23 @@ ccl::datatype getXcclDataType(
|
144 | 153 |
|
145 | 154 | ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
|
146 | 155 | try {
|
147 |
| - if (input.scalar_type() == at::kBool && reduceOp == ReduceOp::SUM) { |
148 |
| - // Map sum to max for bool tensors to avoid overflow issues with sum. |
149 |
| - return ccl::reduction::max; |
| 156 | + if (input.scalar_type() == at::kBool) { |
| 157 | + if (reduceOp == ReduceOp::SUM) { |
| 158 | + // Map sum to max for bool tensors to avoid overflow issues with sum. |
| 159 | + return ccl::reduction::max; |
| 160 | + } |
| 161 | +#ifdef XCCL_HAS_AVG |
| 162 | + if (reduceOp == ReduceOp::AVG) { |
| 163 | + C10_THROW_ERROR( |
| 164 | + TypeError, "Cannot use ReduceOp.AVG with boolean inputs"); |
| 165 | + } |
| 166 | +#endif // XCCL_HAS_AVG |
| 167 | + } |
| 168 | +#if !defined(XCCL_HAS_AVG) |
| 169 | + if (reduceOp == ReduceOp::AVG) { |
| 170 | + return ccl::reduction::sum; |
150 | 171 | }
|
| 172 | +#endif |
151 | 173 | return xcclOps.at(reduceOp);
|
152 | 174 | } catch (const std::out_of_range&) {
|
153 | 175 | C10_THROW_ERROR(
|
@@ -894,6 +916,15 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_impl(
|
894 | 916 | xcclReduceOp,
|
895 | 917 | comm,
|
896 | 918 | ccl::create_stream(stream.queue()));
|
| 919 | +#if !defined(XCCL_HAS_AVG) |
| 920 | + if (opts.reduceOp == ReduceOp::AVG) { |
| 921 | + auto divisor = getSize(); |
| 922 | + c10::StreamGuard guard(stream); |
| 923 | + c10::xpu::XPUCachingAllocator::recordStream( |
| 924 | + output.storage().data_ptr(), stream); |
| 925 | + output.div_(divisor); |
| 926 | + } |
| 927 | +#endif |
897 | 928 | return;
|
898 | 929 | },
|
899 | 930 | OpType::ALLREDUCE,
|
@@ -988,6 +1019,15 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
|
988 | 1019 | xcclReduceOp,
|
989 | 1020 | comm,
|
990 | 1021 | ccl::create_stream(stream.queue()));
|
| 1022 | +#if !defined(XCCL_HAS_AVG) |
| 1023 | + if (opts.reduceOp == ReduceOp::AVG) { |
| 1024 | + auto divisor = getSize(); |
| 1025 | + c10::StreamGuard guard(stream); |
| 1026 | + c10::xpu::XPUCachingAllocator::recordStream( |
| 1027 | + output.storage().data_ptr(), stream); |
| 1028 | + output.div_(divisor); |
| 1029 | + } |
| 1030 | +#endif |
991 | 1031 | return;
|
992 | 1032 | },
|
993 | 1033 | OpType::COALESCED,
|
@@ -1117,6 +1157,15 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce(
|
1117 | 1157 | root,
|
1118 | 1158 | comm,
|
1119 | 1159 | ccl::create_stream(stream.queue()));
|
| 1160 | +#if !defined(XCCL_HAS_AVG) |
| 1161 | + if (opts.reduceOp == ReduceOp::AVG && getRank() == root) { |
| 1162 | + auto divisor = getSize(); |
| 1163 | + c10::StreamGuard guard(stream); |
| 1164 | + c10::xpu::XPUCachingAllocator::recordStream( |
| 1165 | + output.storage().data_ptr(), stream); |
| 1166 | + output.div_(divisor); |
| 1167 | + } |
| 1168 | +#endif |
1120 | 1169 | return;
|
1121 | 1170 | },
|
1122 | 1171 | OpType::REDUCE,
|
@@ -1150,6 +1199,15 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_oop(
|
1150 | 1199 | root,
|
1151 | 1200 | comm,
|
1152 | 1201 | ccl::create_stream(stream.queue()));
|
| 1202 | +#if !defined(XCCL_HAS_AVG) |
| 1203 | + if (opts.reduceOp == ReduceOp::AVG && getRank() == root) { |
| 1204 | + auto divisor = getSize(); |
| 1205 | + c10::StreamGuard guard(stream); |
| 1206 | + c10::xpu::XPUCachingAllocator::recordStream( |
| 1207 | + output.storage().data_ptr(), stream); |
| 1208 | + output.div_(divisor); |
| 1209 | + } |
| 1210 | +#endif |
1153 | 1211 | return;
|
1154 | 1212 | },
|
1155 | 1213 | OpType::REDUCE,
|
@@ -1370,6 +1428,15 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter(
|
1370 | 1428 | xcclReduceOp,
|
1371 | 1429 | comm,
|
1372 | 1430 | ccl::create_stream(stream.queue()));
|
| 1431 | +#if !defined(XCCL_HAS_AVG) |
| 1432 | + if (opts.reduceOp == ReduceOp::AVG) { |
| 1433 | + auto divisor = getSize(); |
| 1434 | + c10::StreamGuard guard(stream); |
| 1435 | + c10::xpu::XPUCachingAllocator::recordStream( |
| 1436 | + output.storage().data_ptr(), stream); |
| 1437 | + output.div_(divisor); |
| 1438 | + } |
| 1439 | +#endif |
1373 | 1440 | return;
|
1374 | 1441 | },
|
1375 | 1442 | [&](at::xpu::XPUStream& Stream,
|
@@ -1453,6 +1520,15 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_scatter_base(
|
1453 | 1520 | xcclReduceOp,
|
1454 | 1521 | comm,
|
1455 | 1522 | ccl::create_stream(stream.queue()));
|
| 1523 | +#if !defined(XCCL_HAS_AVG) |
| 1524 | + if (opts.reduceOp == ReduceOp::AVG) { |
| 1525 | + auto divisor = getSize(); |
| 1526 | + c10::StreamGuard guard(stream); |
| 1527 | + c10::xpu::XPUCachingAllocator::recordStream( |
| 1528 | + output.storage().data_ptr(), stream); |
| 1529 | + output.div_(divisor); |
| 1530 | + } |
| 1531 | +#endif |
1456 | 1532 | return;
|
1457 | 1533 | },
|
1458 | 1534 | OpType::_REDUCE_SCATTER_BASE,
|
@@ -1482,6 +1558,15 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter_tensor_coalesced(
|
1482 | 1558 | xcclReduceOp,
|
1483 | 1559 | comm,
|
1484 | 1560 | ccl::create_stream(stream.queue()));
|
| 1561 | +#if !defined(XCCL_HAS_AVG) |
| 1562 | + if (opts.reduceOp == ReduceOp::AVG) { |
| 1563 | + auto divisor = getSize(); |
| 1564 | + c10::StreamGuard guard(stream); |
| 1565 | + c10::xpu::XPUCachingAllocator::recordStream( |
| 1566 | + output.storage().data_ptr(), stream); |
| 1567 | + output.div_(divisor); |
| 1568 | + } |
| 1569 | +#endif |
1485 | 1570 | return;
|
1486 | 1571 | },
|
1487 | 1572 | OpType::COALESCED,
|
|
0 commit comments