@@ -897,6 +897,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::scatter(
 
 c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_impl(
     at::Tensor& tensor,
+    const char* profilingTitle,
     const AllreduceOptions& opts) {
   return collective(
       tensor,
@@ -928,7 +929,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_impl(
         return;
       },
       OpType::ALLREDUCE,
-      "xccl:all_reduce");
+      profilingTitle);
 }
 
 c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
@@ -956,36 +957,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
       -1, // globalRankStride
       size_); // worldSize
 
-  return collective(
-      tensor,
-      tensor,
-      [&](at::Tensor& input,
-          at::Tensor& output,
-          xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
-        auto xcclDataType = getXcclDataType(input.scalar_type(), true);
-        auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-        ccl::allreduce(
-            input.data_ptr(),
-            output.data_ptr(),
-            (size_t)input.numel(),
-            xcclDataType,
-            xcclReduceOp,
-            comm,
-            ccl::create_stream(stream.queue()));
-#if !defined(XCCL_HAS_AVG)
-        if (opts.reduceOp == ReduceOp::AVG) {
-          auto divisor = getSize();
-          c10::StreamGuard guard(stream);
-          c10::xpu::XPUCachingAllocator::recordStream(
-              output.storage().data_ptr(), stream);
-          output.div_(divisor);
-        }
-#endif
-        return;
-      },
-      OpType::ALLREDUCE,
-      "xccl:all_reduce");
+  return allreduce_impl(tensor, "xccl:all_reduce", opts);
 }
 
 c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
@@ -1621,7 +1593,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::barrier(const BarrierOptions& opts) {
   at::Tensor barrierTensor =
       at::zeros({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat));
 
-  auto work = allreduce_impl(barrierTensor);
+  auto work = allreduce_impl(barrierTensor, "xccl:all_reduce_barrier");
 
   auto xcclWork = dynamic_cast<ProcessGroupXCCL::WorkXCCL*>(work.get());
   TORCH_CHECK(xcclWork);