From f99208fa7a36c482f50f3c5380b01f057a91bc6e Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 2 Jan 2025 18:16:15 +0200 Subject: [PATCH 01/11] TL/MLX5: minor clean and profiler --- src/components/tl/mlx5/alltoall/alltoall_coll.c | 12 ++++++++++-- src/components/tl/mlx5/tl_mlx5_team.c | 4 ++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/components/tl/mlx5/alltoall/alltoall_coll.c b/src/components/tl/mlx5/alltoall/alltoall_coll.c index 70439263fb..6e8cb0ec5d 100644 --- a/src/components/tl/mlx5/alltoall/alltoall_coll.c +++ b/src/components/tl/mlx5/alltoall/alltoall_coll.c @@ -243,6 +243,10 @@ static ucc_status_t ucc_tl_mlx5_fanout_start(ucc_coll_task_t *coll_task) tl_debug(UCC_TASK_LIB(task), "fanout start"); /* start task if completion event received */ UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanout_start", 0); + if (team->a2a->node.sbgp->group_rank == team->a2a->node.asr_rank) { + UCC_TL_MLX5_PROFILE_REQUEST_EVENT( + task, "mlx5_alltoall_wait-on-data_start", 0); + } /* Start fanout */ ucc_progress_enqueue(UCC_TL_CORE_CTX(team)->pq, coll_task); return UCC_OK; @@ -265,6 +269,8 @@ static void ucc_tl_mlx5_fanout_progress(ucc_coll_task_t *coll_task) coll_task->status = UCC_INPROGRESS; return; } + UCC_TL_MLX5_PROFILE_REQUEST_EVENT( + task, "mlx5_alltoall_wait-on-data_complete, fanout_start", 0); } if (UCC_OK == ucc_tl_mlx5_node_fanout(team, task)) { @@ -342,12 +348,14 @@ static ucc_status_t ucc_tl_mlx5_asr_barrier_start(ucc_coll_task_t *coll_task) status = send_done(team, i); } if (status != UCC_OK) { - tl_error(UCC_TASK_LIB(task), "failed sending barrier notice"); + tl_error(UCC_TASK_LIB(task), "failed sending barrier notice"); return status; } + UCC_TL_MLX5_PROFILE_REQUEST_EVENT( + task, "mlx5_alltoall_barrier_send_posted", 0); } coll_task->status = UCC_OK; - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_barreir_done", + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_barrier_done", 0); return ucc_task_complete(coll_task); } diff --git a/src/components/tl/mlx5/tl_mlx5_team.c b/src/components/tl/mlx5/tl_mlx5_team.c index 1e5f6ddf56..6a65274d1d 100644 --- a/src/components/tl/mlx5/tl_mlx5_team.c +++ b/src/components/tl/mlx5/tl_mlx5_team.c @@ -117,7 +117,7 @@ ucc_status_t ucc_tl_mlx5_team_destroy(ucc_base_team_t *tl_team) return UCC_OK; } -static inline ucc_status_t ucc_tl_mlx5_a2a_team_test(ucc_base_team_t *team) +static inline ucc_status_t ucc_tl_mlx5_alltoall_team_test(ucc_base_team_t *team) { ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t); @@ -253,7 +253,7 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team) goto initial_sync_post; } - a2a_status = ucc_tl_mlx5_a2a_team_test(team); + a2a_status = ucc_tl_mlx5_alltoall_team_test(team); if (a2a_status < 0) { tl_warn(team->context->lib, "ALLTOALL tl team: %p creation failed %d", team, a2a_status); From 6e1ec374709db189a5b8ca73950c876bce37a079 Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 9 Jan 2025 16:45:43 +0200 Subject: [PATCH 02/11] TL/MLX5: fix fences in WQEs --- src/components/tl/mlx5/tl_mlx5_wqe.c | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/components/tl/mlx5/tl_mlx5_wqe.c b/src/components/tl/mlx5/tl_mlx5_wqe.c index cf4d590658..cee82fd1f8 100644 --- a/src/components/tl/mlx5/tl_mlx5_wqe.c +++ b/src/components/tl/mlx5/tl_mlx5_wqe.c @@ -57,7 +57,7 @@ ucc_status_t ucc_tl_mlx5_post_transpose(struct ibv_qp *qp, uint32_t src_mr_lkey, uint32_t n_ds = 4; struct ibv_qp_ex * qp_ex = ibv_qp_to_qp_ex(qp); struct 
mlx5dv_qp_ex * mqp = mlx5dv_qp_ex_from_ibv_qp_ex(qp_ex); - int fm_ce_se = 0; + int fm_ce_se = MLX5_WQE_CTRL_INITIATOR_SMALL_FENCE; char wqe_desc[n_ds * DS_SIZE]; struct mlx5_wqe_ctrl_seg *ctrl; struct mlx5_wqe_data_seg *data; @@ -153,8 +153,7 @@ ucc_status_t ucc_tl_mlx5_post_umr(struct ibv_qp * qp, sizeof(struct mlx5_wqe_mkey_context_seg) + sizeof(struct mlx5_wqe_umr_pointer_seg)) / DS_SIZE; - uint8_t fm_ce_se = - MLX5_WQE_CTRL_INITIATOR_SMALL_FENCE | MLX5_WQE_CTRL_CQ_UPDATE; + uint8_t fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE; struct ibv_qp_ex * qp_ex = ibv_qp_to_qp_ex(qp); struct mlx5dv_qp_ex * mqp = mlx5dv_qp_ex_from_ibv_qp_ex(qp_ex); struct mlx5_wqe_ctrl_seg * ctrl; @@ -275,7 +274,7 @@ ucc_status_t ucc_tl_mlx5_post_wait_on_data(struct ibv_qp *qp, uint64_t value, uint32_t n_ds = 3; //CTRL + Wait on Data of Size 2 struct ibv_qp_ex * qp_ex = ibv_qp_to_qp_ex(qp); struct mlx5dv_qp_ex *mqp = mlx5dv_qp_ex_from_ibv_qp_ex(qp_ex); - uint8_t fm_ce_se = MLX5_WQE_CTRL_FENCE | MLX5_WQE_CTRL_CQ_UPDATE; + uint8_t fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE; char wqe_desc[n_ds * DS_SIZE]; struct mlx5_wqe_ctrl_seg *ctrl; wait_on_data_seg_t * wseg; From a84a08a408a19c38e038f197742c9f27663eba70 Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 9 Jan 2025 16:49:04 +0200 Subject: [PATCH 03/11] CODESTYLE: align --- src/components/tl/mlx5/tl_mlx5_wqe.c | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/src/components/tl/mlx5/tl_mlx5_wqe.c b/src/components/tl/mlx5/tl_mlx5_wqe.c index cee82fd1f8..0794b6df22 100644 --- a/src/components/tl/mlx5/tl_mlx5_wqe.c +++ b/src/components/tl/mlx5/tl_mlx5_wqe.c @@ -153,13 +153,14 @@ ucc_status_t ucc_tl_mlx5_post_umr(struct ibv_qp * qp, sizeof(struct mlx5_wqe_mkey_context_seg) + sizeof(struct mlx5_wqe_umr_pointer_seg)) / DS_SIZE; - uint8_t fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE; - struct ibv_qp_ex * qp_ex = ibv_qp_to_qp_ex(qp); - struct mlx5dv_qp_ex * mqp = mlx5dv_qp_ex_from_ibv_qp_ex(qp_ex); - struct mlx5_wqe_ctrl_seg * ctrl; - struct mlx5_wqe_umr_ctrl_seg * umr_ctrl_seg; + uint8_t fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE; + struct ibv_qp_ex *qp_ex = ibv_qp_to_qp_ex(qp); + struct mlx5dv_qp_ex *mqp = + mlx5dv_qp_ex_from_ibv_qp_ex(qp_ex); + struct mlx5_wqe_ctrl_seg *ctrl; + struct mlx5_wqe_umr_ctrl_seg *umr_ctrl_seg; struct mlx5_wqe_mkey_context_seg *mk_seg; - struct mlx5_wqe_umr_pointer_seg * pseg; + struct mlx5_wqe_umr_pointer_seg *pseg; char wqe_desc[n_ds * DS_SIZE]; int xlat_size; @@ -269,12 +270,12 @@ ucc_status_t ucc_tl_mlx5_post_wait_on_data(struct ibv_qp *qp, uint64_t value, void *task_ptr) { - uint32_t opcode = MLX5_OPCODE_WAIT; - uint32_t opmode = 0x1; //wait on data - uint32_t n_ds = 3; //CTRL + Wait on Data of Size 2 - struct ibv_qp_ex * qp_ex = ibv_qp_to_qp_ex(qp); - struct mlx5dv_qp_ex *mqp = mlx5dv_qp_ex_from_ibv_qp_ex(qp_ex); - uint8_t fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE; + uint32_t opcode = MLX5_OPCODE_WAIT; + uint32_t opmode = 0x1; //wait on data + uint32_t n_ds = 3; //CTRL + Wait on Data of Size 2 + struct ibv_qp_ex *qp_ex = ibv_qp_to_qp_ex(qp); + struct mlx5dv_qp_ex *mqp = mlx5dv_qp_ex_from_ibv_qp_ex(qp_ex); + uint8_t fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE; char wqe_desc[n_ds * DS_SIZE]; struct mlx5_wqe_ctrl_seg *ctrl; wait_on_data_seg_t * wseg; From c226e9f8c50337685f2ada03775fdb82ccb51789 Mon Sep 17 00:00:00 2001 From: snordmann Date: Fri, 10 Jan 2025 14:40:36 +0200 Subject: [PATCH 04/11] TL/MLX5: further WQE flags revision --- src/components/tl/mlx5/alltoall/alltoall_inline.h | 2 +- src/components/tl/mlx5/tl_mlx5_wqe.c | 2 +- 2 files 
changed, 2 insertions(+), 2 deletions(-) diff --git a/src/components/tl/mlx5/alltoall/alltoall_inline.h b/src/components/tl/mlx5/alltoall/alltoall_inline.h index 02d82acb35..a1d081685f 100644 --- a/src/components/tl/mlx5/alltoall/alltoall_inline.h +++ b/src/components/tl/mlx5/alltoall/alltoall_inline.h @@ -73,7 +73,7 @@ static inline ucc_status_t send_atomic(ucc_tl_mlx5_alltoall_t *a2a, struct mlx5dv_qp_ex *qp_dv; qp_ex = tl_mlx5_get_qp_ex(a2a, rank); - qp_ex->wr_flags = 0; + qp_ex->wr_flags = IBV_SEND_FENCE; ibv_wr_atomic_fetch_add(qp_ex, rkey, (uintptr_t)remote_addr, 1ULL); if (a2a->is_dc) { qp_dv = mlx5dv_qp_ex_from_ibv_qp_ex(qp_ex); diff --git a/src/components/tl/mlx5/tl_mlx5_wqe.c b/src/components/tl/mlx5/tl_mlx5_wqe.c index 0794b6df22..d0ece52902 100644 --- a/src/components/tl/mlx5/tl_mlx5_wqe.c +++ b/src/components/tl/mlx5/tl_mlx5_wqe.c @@ -57,7 +57,7 @@ ucc_status_t ucc_tl_mlx5_post_transpose(struct ibv_qp *qp, uint32_t src_mr_lkey, uint32_t n_ds = 4; struct ibv_qp_ex * qp_ex = ibv_qp_to_qp_ex(qp); struct mlx5dv_qp_ex * mqp = mlx5dv_qp_ex_from_ibv_qp_ex(qp_ex); - int fm_ce_se = MLX5_WQE_CTRL_INITIATOR_SMALL_FENCE; + int fm_ce_se = 0; char wqe_desc[n_ds * DS_SIZE]; struct mlx5_wqe_ctrl_seg *ctrl; struct mlx5_wqe_data_seg *data; From 7f9f604ba9f515d55f1f80f99f49e645f966e0cb Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 15 Jan 2025 13:07:25 +0200 Subject: [PATCH 05/11] TL/MLX5: remove unnecessary fence on atomic WQE --- src/components/tl/mlx5/alltoall/alltoall_inline.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/components/tl/mlx5/alltoall/alltoall_inline.h b/src/components/tl/mlx5/alltoall/alltoall_inline.h index a1d081685f..02d82acb35 100644 --- a/src/components/tl/mlx5/alltoall/alltoall_inline.h +++ b/src/components/tl/mlx5/alltoall/alltoall_inline.h @@ -73,7 +73,7 @@ static inline ucc_status_t send_atomic(ucc_tl_mlx5_alltoall_t *a2a, struct mlx5dv_qp_ex *qp_dv; qp_ex = tl_mlx5_get_qp_ex(a2a, rank); - qp_ex->wr_flags = IBV_SEND_FENCE; + qp_ex->wr_flags = 0; ibv_wr_atomic_fetch_add(qp_ex, rkey, (uintptr_t)remote_addr, 1ULL); if (a2a->is_dc) { qp_dv = mlx5dv_qp_ex_from_ibv_qp_ex(qp_ex); From 4a3ffcadc751c1fe6cd11289b23dd9085e59073e Mon Sep 17 00:00:00 2001 From: Mamzi Bayatpour Date: Tue, 8 Oct 2024 23:24:50 +0300 Subject: [PATCH 06/11] TL/MLX5: add nonblocking cudaMemcpy support --- src/components/tl/mlx5/mcast/tl_mlx5_mcast.h | 17 ++++++++++ .../tl/mlx5/mcast/tl_mlx5_mcast_coll.c | 10 +++++- .../tl/mlx5/mcast/tl_mlx5_mcast_coll.h | 1 + .../tl/mlx5/mcast/tl_mlx5_mcast_progress.c | 34 +++++++++++++------ src/components/tl/mlx5/tl_mlx5_coll.c | 8 ++--- 5 files changed, 54 insertions(+), 16 deletions(-) diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h index 663ee636ed..3772f55616 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h @@ -441,6 +441,8 @@ typedef struct ucc_tl_mlx5_mcast_coll_req { ucc_service_coll_req_t *allgather_rkeys_req; ucc_service_coll_req_t *barrier_req; void *recv_rreg; + ucc_ee_executor_task_t *exec_task; + ucc_coll_task_t *coll_task; } ucc_tl_mlx5_mcast_coll_req_t; typedef struct ucc_tl_mlx5_mcast_oob_p2p_context { @@ -555,6 +557,21 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_post_user_recv_buffers(ucc_tl_mlx5_ return UCC_OK; } +#define EXEC_TASK_TEST(_errmsg, _etask, _lib) do { \ + if (_etask != NULL) { \ + status = ucc_ee_executor_task_test(_etask); \ + if (status > 0) { \ + return status; \ + } \ + 
ucc_ee_executor_task_finalize(_etask); \ + _etask = NULL; \ + if (ucc_unlikely(status < 0)) { \ + tl_error(_lib, _errmsg); \ + return status; \ + } \ + } \ +} while(0) + ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *tl_context, ucc_tl_mlx5_mcast_team_t **mcast_team, ucc_tl_mlx5_mcast_context_t *ctx, diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c index b6fbe84e3d..cf813fd5af 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c @@ -33,6 +33,10 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_r_window_recycle(ucc_tl_mlx5_mcast_ return status; } + while (req->exec_task != NULL) { + EXEC_TASK_TEST("failed to complete the nb memcpy", req->exec_task, comm->lib); + } + comm->bcast_comm.n_mcast_reliable++; for (; comm->bcast_comm.last_acked < comm->psn; comm->bcast_comm.last_acked++) { @@ -267,7 +271,10 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_start(ucc_coll_task_t *coll_task) return ucc_task_complete(coll_task); } - coll_task->status = status; + ucc_assert(task->coll_mcast.req_handle != NULL); + + coll_task->status = status; + task->coll_mcast.req_handle->coll_task = coll_task; return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(mlx5_team)->pq, &task->super); } @@ -333,6 +340,7 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task) { task->super.post = ucc_tl_mlx5_mcast_bcast_start; task->super.progress = ucc_tl_mlx5_mcast_collective_progress; + task->super.flags = UCC_COLL_TASK_FLAG_EXECUTOR; return UCC_OK; } diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h index f34e8827f4..ccc563ecc7 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h @@ -16,4 +16,5 @@ ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* _req); ucc_status_t ucc_tl_mlx5_mcast_check_support(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team); + #endif diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c index 3620cf629f..8031af6dc0 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c @@ -391,9 +391,10 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com ucc_tl_mlx5_mcast_coll_req_t *req, struct pp_packet* pp) { - ucc_status_t status = UCC_OK; - void *dest; - ucc_memory_type_t mem_type; + ucc_status_t status = UCC_OK; + void *dest; + ucc_ee_executor_task_args_t eargs; + ucc_ee_executor_t *exec; ucc_assert(pp->psn >= req->start_psn && pp->psn < req->start_psn + req->num_packets); @@ -402,19 +403,30 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com if (pp->length > 0 ) { dest = req->ptr + PSN_TO_RECV_OFFSET(pp->psn, req, comm); - - if (comm->cuda_mem_enabled) { - mem_type = UCC_MEMORY_TYPE_CUDA; - } else { - mem_type = UCC_MEMORY_TYPE_HOST; + while (req->exec_task != NULL) { + EXEC_TASK_TEST("failed to complete the nb memcpy", req->exec_task, comm->lib); } - status = ucc_mc_memcpy(dest, (void*) pp->buf, pp->length, - mem_type, mem_type); + /* for cuda copy, exec is nonblocking but for host copy it is blocking */ + status = ucc_coll_task_get_executor(req->coll_task, &exec); if (ucc_unlikely(status != UCC_OK)) { - tl_error(comm->lib, "failed to copy buffer"); return status; } + + eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY; 
+ eargs.copy.src = (void*) pp->buf; + eargs.copy.dst = dest; + eargs.copy.len = pp->length; + + assert(req->exec_task == NULL); + status = ucc_ee_executor_task_post(exec, &eargs, &req->exec_task); + if (ucc_unlikely(status != UCC_OK)) { + return status; + } + + if (req->exec_task != NULL) { + EXEC_TASK_TEST("failed to progress the memcpy", req->exec_task, comm->lib); + } } comm->r_window[pp->psn & (comm->bcast_comm.wsize-1)] = pp; diff --git a/src/components/tl/mlx5/tl_mlx5_coll.c b/src/components/tl/mlx5/tl_mlx5_coll.c index 94d336ba6e..aabdbf8010 100644 --- a/src/components/tl/mlx5/tl_mlx5_coll.c +++ b/src/components/tl/mlx5/tl_mlx5_coll.c @@ -14,8 +14,8 @@ ucc_status_t ucc_tl_mlx5_coll_mcast_init(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, ucc_coll_task_t **task_h) { - ucc_status_t status = UCC_OK; - ucc_tl_mlx5_task_t *task = NULL; + ucc_status_t status = UCC_OK; + ucc_tl_mlx5_task_t *task = NULL; status = ucc_tl_mlx5_mcast_check_support(coll_args, team); if (UCC_OK != status) { @@ -35,12 +35,14 @@ ucc_status_t ucc_tl_mlx5_coll_mcast_init(ucc_base_coll_args_t *coll_args, if (ucc_unlikely(UCC_OK != status)) { goto free_task; } + *task_h = &(task->super); break; case UCC_COLL_TYPE_ALLGATHER: status = ucc_tl_mlx5_mcast_allgather_init(task); if (ucc_unlikely(UCC_OK != status)) { goto free_task; } + *task_h = &(task->super); break; default: status = UCC_ERR_NOT_SUPPORTED; @@ -48,8 +50,6 @@ ucc_status_t ucc_tl_mlx5_coll_mcast_init(ucc_base_coll_args_t *coll_args, goto free_task; } - *task_h = &(task->super); - tl_debug(UCC_TASK_LIB(task), "initialized mcast collective task %p", task); return UCC_OK; From 3e2633dddd9e257a8bc780bc20404a21e3a22278 Mon Sep 17 00:00:00 2001 From: Mamzi Bayatpour Date: Thu, 19 Dec 2024 12:04:52 -0800 Subject: [PATCH 07/11] TL/MLX5: mcast multi-group support part 1 --- src/components/tl/mlx5/mcast/tl_mlx5_mcast.h | 31 ++- .../tl/mlx5/mcast/tl_mlx5_mcast_helper.c | 206 +++++++++++------- .../tl/mlx5/mcast/tl_mlx5_mcast_helper.h | 21 +- .../tl_mlx5_mcast_one_sided_reliability.c | 4 +- .../tl/mlx5/mcast/tl_mlx5_mcast_team.c | 34 +-- src/components/tl/mlx5/tl_mlx5_team.c | 2 +- 6 files changed, 166 insertions(+), 132 deletions(-) diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h index 3772f55616..49f2292166 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h @@ -199,22 +199,25 @@ struct pp_packet { uintptr_t buf; // buffer address, initialized once }; +struct mcast_group { + struct ibv_qp *qp; + struct ibv_ah *ah; + uint16_t lid; + union ibv_gid mgid; + struct sockaddr_in6 mcast_addr; +}; + struct mcast_ctx { - struct ibv_qp *qp; - struct ibv_ah *ah; struct ibv_send_wr swr; struct ibv_sge ssg; - + struct ibv_cq *scq; + struct ibv_cq *rcq; + struct ibv_srq *srq; + struct mcast_group groups[MAX_GROUP_COUNT]; // RC connection info for supporing one-sided based relibality struct ibv_qp **rc_qp; uint16_t *rc_lid; union ibv_gid *rc_gid; - - // multiple mcast group - struct ibv_qp **qp_list; - struct ibv_ah **ah_list; - struct ibv_send_wr *swr_list; - struct ibv_sge *ssg_list; }; struct packet { @@ -303,15 +306,10 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm { ucc_tl_mlx5_mcast_coll_comm_init_spec_t params; ucc_tl_mlx5_mcast_p2p_interface_t p2p; int tx; - struct ibv_cq *scq; - struct ibv_cq *rcq; - struct ibv_srq *srq; ucc_rank_t rank; ucc_rank_t commsize; char *grh_buf; struct ibv_mr *grh_mr; - uint16_t mcast_lid; - union ibv_gid mgid; unsigned 
max_inline; size_t max_eager; int max_per_packet; @@ -334,7 +332,6 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm { int comm_id; void *p2p_ctx; ucc_base_lib_t *lib; - struct sockaddr_in6 mcast_addr; int cuda_mem_enabled; ucc_tl_mlx5_mcast_join_info_t *group_setup_info; ucc_service_coll_req_t *group_setup_info_req; @@ -492,7 +489,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_post_recv_buffers(ucc_tl_mlx5_mcast } if (i != 0) { rwr[i-1].next = NULL; - if (ibv_post_recv(comm->mcast.qp, &rwr[0], &bad_wr)) { + if (ibv_post_recv(comm->mcast.groups[0].qp, &rwr[0], &bad_wr)) { tl_error(comm->lib, "failed to prepost recvs: errno %d", errno); return UCC_ERR_NO_RESOURCE; } @@ -545,7 +542,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_post_user_recv_buffers(ucc_tl_mlx5_ if (i > 0) { rwr[i-1].next = NULL; - if (ibv_post_recv(comm->mcast.qp_list[group_id], &rwr[0], &bad_wr)) { + if (ibv_post_recv(comm->mcast.groups[group_id].qp, &rwr[0], &bad_wr)) { tl_error(comm->lib, "Failed to prepost recvs: errno %d buffer count %d", errno, i); return UCC_ERR_NO_RESOURCE; diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c index a116c08cf8..ae736a37a9 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c @@ -282,11 +282,14 @@ ucc_status_t ucc_tl_mlx5_setup_mcast_group_join_post(ucc_tl_mlx5_mcast_coll_comm ucc_status_t ucc_tl_mlx5_mcast_init_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, ucc_tl_mlx5_mcast_coll_comm_t *comm) { - struct ibv_qp_init_attr qp_init_attr = {0}; + int max_inline = INT_MAX; + struct ibv_qp_init_attr qp_init_attr = {0}; + int i; + int j; qp_init_attr.qp_type = IBV_QPT_UD; - qp_init_attr.send_cq = comm->scq; - qp_init_attr.recv_cq = comm->rcq; + qp_init_attr.send_cq = comm->mcast.scq; //cq can be shared between multiple QPs + qp_init_attr.recv_cq = comm->mcast.rcq; qp_init_attr.sq_sig_all = 0; qp_init_attr.cap.max_send_wr = comm->params.sx_depth; qp_init_attr.cap.max_recv_wr = comm->params.rx_depth; @@ -294,41 +297,68 @@ ucc_status_t ucc_tl_mlx5_mcast_init_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, qp_init_attr.cap.max_send_sge = comm->params.sx_sge; qp_init_attr.cap.max_recv_sge = comm->params.rx_sge; - comm->mcast.qp = ibv_create_qp(ctx->pd, &qp_init_attr); - if (!comm->mcast.qp) { - tl_warn(ctx->lib, "failed to create mcast qp, errno %d", errno); - return UCC_ERR_NO_RESOURCE; + for (i = 0; i < comm->mcast_group_count; i++) { + comm->mcast.groups[i].qp = ibv_create_qp(ctx->pd, &qp_init_attr); + if (!comm->mcast.groups[i].qp) { + tl_error(ctx->lib, "Failed to create mcast UD qp index %d, errno %d", i, errno); + goto error; + } + if (qp_init_attr.cap.max_inline_data < max_inline) { + max_inline = qp_init_attr.cap.max_inline_data; + } } if (comm->cuda_mem_enabled) { /* max inline send otherwise it segfault during ibv send */ comm->max_inline = 0; } else { - comm->max_inline = qp_init_attr.cap.max_inline_data; + comm->max_inline = max_inline; } return UCC_OK; + +error: + for (j = 0; j < i; j++) { + ibv_destroy_qp(comm->mcast.groups[j].qp); + comm->mcast.groups[j].qp = NULL; + } + return UCC_ERR_NO_RESOURCE; } static ucc_status_t ucc_tl_mlx5_mcast_create_ah(ucc_tl_mlx5_mcast_coll_comm_t *comm) { + int i, j, ret; struct ibv_ah_attr ah_attr = { .is_global = 1, .grh = {.sgid_index = 0}, - .dlid = comm->mcast_lid, .sl = DEF_SL, .src_path_bits = DEF_SRC_PATH_BITS, .port_num = comm->ctx->ib_port }; - memcpy(ah_attr.grh.dgid.raw, &comm->mgid, 
sizeof(ah_attr.grh.dgid.raw)); + for (i = 0; i < comm->mcast_group_count; i ++) { + ah_attr.dlid = comm->mcast.groups[i].lid; + memcpy(ah_attr.grh.dgid.raw, &comm->mcast.groups[i].mgid, sizeof(ah_attr.grh.dgid.raw)); - comm->mcast.ah = ibv_create_ah(comm->ctx->pd, &ah_attr); - if (!comm->mcast.ah) { - tl_warn(comm->lib, "failed to create AH"); - return UCC_ERR_NO_RESOURCE; + comm->mcast.groups[i].ah = ibv_create_ah(comm->ctx->pd, &ah_attr); + if (!comm->mcast.groups[i].ah) { + tl_error(comm->lib, "failed to create AH index %d", i); + goto error; + } } + return UCC_OK; + +error: + for (j = 0; j < i; j++) { + ret = ibv_destroy_ah(comm->mcast.groups[j].ah); + if (ret) { + tl_error(comm->lib, "couldn't destroy ah"); + return UCC_ERR_NO_RESOURCE; + } + comm->mcast.groups[j].ah = NULL; + } + return UCC_ERR_NO_RESOURCE; } ucc_status_t ucc_tl_mlx5_mcast_setup_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, @@ -337,16 +367,15 @@ ucc_status_t ucc_tl_mlx5_mcast_setup_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, struct ibv_port_attr port_attr; struct ibv_qp_attr attr; uint16_t pkey; + int i; ibv_query_port(ctx->ctx, ctx->ib_port, &port_attr); - for (ctx->pkey_index = 0; ctx->pkey_index < port_attr.pkey_tbl_len; ++ctx->pkey_index) { ibv_query_pkey(ctx->ctx, ctx->ib_port, ctx->pkey_index, &pkey); if (pkey == DEF_PKEY) break; } - if (ctx->pkey_index >= port_attr.pkey_tbl_len) { ctx->pkey_index = 0; ibv_query_pkey(ctx->ctx, ctx->ib_port, ctx->pkey_index, &pkey); @@ -359,43 +388,53 @@ ucc_status_t ucc_tl_mlx5_mcast_setup_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, "index 0 pkey:0x%04x", DEF_PKEY, ctx->ib_port, pkey); } - attr.qp_state = IBV_QPS_INIT; - attr.pkey_index = ctx->pkey_index; - attr.port_num = ctx->ib_port; - attr.qkey = DEF_QKEY; + for (i = 0; i < comm->mcast_group_count; i++) { + attr.qp_state = IBV_QPS_INIT; + attr.pkey_index = ctx->pkey_index; + attr.port_num = ctx->ib_port; + attr.qkey = DEF_QKEY; - if (ibv_modify_qp(comm->mcast.qp, &attr, - IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_QKEY)) { - tl_warn(ctx->lib, "failed to move mcast qp to INIT, errno %d", errno); - return UCC_ERR_NO_RESOURCE; - } + if (ibv_modify_qp(comm->mcast.groups[i].qp, &attr, + IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_QKEY)) { + tl_error(ctx->lib, "failed to move mcast qp to INIT, errno %d", errno); + goto error; + } - if (ibv_attach_mcast(comm->mcast.qp, &comm->mgid, comm->mcast_lid)) { - tl_warn(ctx->lib, "failed to attach QP to the mcast group, errno %d", errno); - return UCC_ERR_NO_RESOURCE; - } + if (ibv_attach_mcast(comm->mcast.groups[i].qp, &comm->mcast.groups[i].mgid, + comm->mcast.groups[i].lid)) { + tl_error(ctx->lib, "failed to attach QP to the mcast group with mcast_lid %d , errno %d", + errno, comm->mcast.groups[i].lid); + goto error; + } - /* Ok, now cycle to RTR on everyone */ - attr.qp_state = IBV_QPS_RTR; - if (ibv_modify_qp(comm->mcast.qp, &attr, IBV_QP_STATE)) { - tl_warn(ctx->lib, "failed to modify QP to RTR, errno %d", errno); - return UCC_ERR_NO_RESOURCE; - } + attr.qp_state = IBV_QPS_RTR; + if (ibv_modify_qp(comm->mcast.groups[i].qp, &attr, IBV_QP_STATE)) { + tl_error(ctx->lib, "failed to modify QP to RTR, errno %d", errno); + goto error; + } - attr.qp_state = IBV_QPS_RTS; - attr.sq_psn = DEF_PSN; - if (ibv_modify_qp(comm->mcast.qp, &attr, IBV_QP_STATE | IBV_QP_SQ_PSN)) { - tl_warn(ctx->lib, "failed to modify QP to RTS, errno %d", errno); - return UCC_ERR_NO_RESOURCE; + attr.qp_state = IBV_QPS_RTS; + attr.sq_psn = DEF_PSN; + if (ibv_modify_qp(comm->mcast.groups[i].qp, &attr, 
IBV_QP_STATE | IBV_QP_SQ_PSN)) { + tl_error(ctx->lib, "failed to modify QP to RTS, errno %d", errno); + goto error; + } } - /* Create the address handle */ + /* create the address handle */ if (UCC_OK != ucc_tl_mlx5_mcast_create_ah(comm)) { tl_warn(ctx->lib, "failed to create adress handle"); - return UCC_ERR_NO_RESOURCE; + goto error; } return UCC_OK; + +error: + for (i=0; i < comm->mcast_group_count; i++) { + ibv_destroy_qp(comm->mcast.groups[i].qp); + comm->mcast.groups[i].qp = NULL; + } + return UCC_ERR_NO_RESOURCE; } ucc_status_t ucc_tl_mlx5_mcast_create_rc_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, @@ -410,8 +449,8 @@ ucc_status_t ucc_tl_mlx5_mcast_create_rc_qps(ucc_tl_mlx5_mcast_coll_context_t *c srq_init_attr.attr.max_wr = comm->params.rx_depth; srq_init_attr.attr.max_sge = 2; - comm->srq = ibv_create_srq(ctx->pd, &srq_init_attr); - if (!comm->srq) { + comm->mcast.srq = ibv_create_srq(ctx->pd, &srq_init_attr); + if (!comm->mcast.srq) { tl_error(ctx->lib, "ibv_create_srq() failed"); return UCC_ERR_NO_RESOURCE; } @@ -426,10 +465,10 @@ ucc_status_t ucc_tl_mlx5_mcast_create_rc_qps(ucc_tl_mlx5_mcast_coll_context_t *c for (i = 0; i < comm->commsize; i++) { memset(&qp_init_attr, 0, sizeof(qp_init_attr)); - qp_init_attr.srq = comm->srq; + qp_init_attr.srq = comm->mcast.srq; qp_init_attr.qp_type = IBV_QPT_RC; - qp_init_attr.send_cq = comm->scq; - qp_init_attr.recv_cq = comm->rcq; + qp_init_attr.send_cq = comm->mcast.scq; + qp_init_attr.recv_cq = comm->mcast.rcq; qp_init_attr.sq_sig_all = 0; qp_init_attr.cap.max_send_wr = comm->params.sx_depth; qp_init_attr.cap.max_recv_wr = 0; // has srq @@ -454,7 +493,7 @@ ucc_status_t ucc_tl_mlx5_mcast_create_rc_qps(ucc_tl_mlx5_mcast_coll_context_t *c } } - if (ibv_destroy_srq(comm->srq)) { + if (ibv_destroy_srq(comm->mcast.srq)) { tl_error(comm->lib, "ibv_destroy_srq failed"); return UCC_ERR_NO_RESOURCE; } @@ -538,7 +577,7 @@ ucc_status_t ucc_tl_mlx5_fini_mcast_group(ucc_tl_mlx5_mcast_coll_context_t *ctx, char buf[40]; const char *dst; - dst = inet_ntop(AF_INET6, &comm->mcast_addr, buf, 40); + dst = inet_ntop(AF_INET6, &comm->mcast.groups[0].mcast_addr, buf, 40); if (NULL == dst) { tl_error(comm->lib, "inet_ntop failed"); return UCC_ERR_NO_RESOURCE; @@ -546,7 +585,7 @@ ucc_status_t ucc_tl_mlx5_fini_mcast_group(ucc_tl_mlx5_mcast_coll_context_t *ctx, tl_debug(ctx->lib, "mcast leave: ctx %p, comm %p, dgid: %s", ctx, comm, buf); - if (rdma_leave_multicast(ctx->id, (struct sockaddr*)&comm->mcast_addr)) { + if (rdma_leave_multicast(ctx->id, (struct sockaddr*)&comm->mcast.groups[0].mcast_addr)) { tl_error(comm->lib, "mcast rmda_leave_multicast failed"); return UCC_ERR_NO_RESOURCE; } @@ -559,11 +598,10 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm) ucc_tl_mlx5_mcast_context_t *mcast_ctx = ucc_container_of(comm->ctx, ucc_tl_mlx5_mcast_context_t, mcast_context); ucc_tl_mlx5_context_t *mlx5_ctx = ucc_container_of(mcast_ctx, ucc_tl_mlx5_context_t, mcast); ucc_context_h context = mlx5_ctx->super.super.ucc_context; - int ret; + int ret, i; ucc_status_t status; - tl_debug(comm->lib, "cleaning mcast comm: %p, id %d, mlid %x", - comm, comm->comm_id, comm->mcast_lid); + tl_debug(comm->lib, "cleaning mcast comm: %p, id %d", comm, comm->comm_id); while (UCC_INPROGRESS == (status = ucc_tl_mlx5_mcast_reliable(comm))) { ucc_context_progress(context); @@ -575,32 +613,48 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm) return status; } - if (comm->mcast.qp) { - ret = ibv_detach_mcast(comm->mcast.qp, &comm->mgid, 
comm->mcast_lid); - if (ret) { - tl_error(comm->lib, "couldn't detach QP, ret %d, errno %d", ret, errno); - return UCC_ERR_NO_RESOURCE; + for (i = 0; i < comm->mcast_group_count; i++) { + if (comm->mcast.groups[i].qp) { + ret = ibv_detach_mcast(comm->mcast.groups[i].qp, &(comm->mcast.groups[i].mgid), comm->mcast.groups[i].lid); + if (ret) { + tl_error(comm->lib, "couldn't detach QP, ret %d, errno %d", ret, errno); + return UCC_ERR_NO_RESOURCE; + } + + ret = ibv_destroy_qp(comm->mcast.groups[i].qp); + if (ret) { + tl_error(comm->lib, "failed to destroy QP %d", ret); + return UCC_ERR_NO_RESOURCE; + } + + comm->mcast.groups[i].qp = NULL; + } + if (comm->mcast.groups[i].ah) { + ret = ibv_destroy_ah(comm->mcast.groups[i].ah); + if (ret) { + tl_error(comm->lib, "couldn't destroy ah"); + return UCC_ERR_NO_RESOURCE; + } + comm->mcast.groups[i].ah = NULL; } } - if (comm->mcast.qp) { - ret = ibv_destroy_qp(comm->mcast.qp); - if (ret) { - tl_error(comm->lib, "failed to destroy QP %d", ret); - return UCC_ERR_NO_RESOURCE; - } + status = ucc_tl_mlx5_fini_mcast_group(comm->ctx, comm); + if (status) { + tl_error(comm->lib, "couldn't leave mcast group"); + return status; } - if (comm->rcq) { - ret = ibv_destroy_cq(comm->rcq); + if (comm->mcast.rcq) { + ret = ibv_destroy_cq(comm->mcast.rcq); if (ret) { tl_error(comm->lib, "couldn't destroy rcq"); return UCC_ERR_NO_RESOURCE; } } - if (comm->scq) { - ret = ibv_destroy_cq(comm->scq); + if (comm->mcast.scq) { + ret = ibv_destroy_cq(comm->mcast.scq); if (ret) { tl_error(comm->lib, "couldn't destroy scq"); return UCC_ERR_NO_RESOURCE; @@ -643,22 +697,6 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm) ucc_free(comm->call_rsgs); } - if (comm->mcast.ah) { - ret = ibv_destroy_ah(comm->mcast.ah); - if (ret) { - tl_error(comm->lib, "couldn't destroy ah"); - return UCC_ERR_NO_RESOURCE; - } - } - - if (comm->mcast_lid) { - status = ucc_tl_mlx5_fini_mcast_group(comm->ctx, comm); - if (status) { - tl_error(comm->lib, "couldn't leave mcast group"); - return status; - } - } - if (comm->ctx->params.print_nack_stats) { tl_debug(comm->lib, "comm_id %d, comm_size %d, comm->psn %d, rank %d, " "nacks counter %d, n_mcast_rel %d", diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h index d0b1a1ddd3..fc5296568d 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h @@ -16,7 +16,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_poll_send(ucc_tl_mlx5_mcast_coll_co struct ibv_wc wc; int num_comp; - num_comp = ibv_poll_cq(comm->scq, 1, &wc); + num_comp = ibv_poll_cq(comm->mcast.scq, 1, &wc); tl_trace(comm->lib, "polled send completions: %d", num_comp); @@ -108,7 +108,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_send(ucc_tl_mlx5_mcast_coll_comm_t tl_trace(comm->lib, "post_send, psn %d, length %d, zcopy %d, signaled %d", pp->psn, pp->length, zcopy, swr[0].send_flags & IBV_SEND_SIGNALED); - if (0 != (rc = ibv_post_send(comm->mcast.qp, &swr[0], &bad_wr))) { + if (0 != (rc = ibv_post_send(comm->mcast.groups[0].qp, &swr[0], &bad_wr))) { tl_error(comm->lib, "post send failed: ret %d, start_psn %d, to_send %d, " "to_recv %d, length %d, psn %d, inline %d", rc, req->start_psn, req->to_send, req->to_recv, @@ -202,7 +202,7 @@ static inline int ucc_tl_mlx5_mcast_recv(ucc_tl_mlx5_mcast_coll_comm_t *comm, while (num_left > 0) { memset(wc, 0, sizeof(struct ibv_wc) * POLL_PACKED); - num_comp = ibv_poll_cq(comm->rcq, POLL_PACKED, wc); + 
num_comp = ibv_poll_cq(comm->mcast.rcq, POLL_PACKED, wc); if (num_comp < 0) { tl_error(comm->lib, "recv queue poll completion failed %d", num_comp); @@ -329,19 +329,19 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_send_collective(ucc_tl_mlx5_mcast_c mcast_group_index = group_id; } - swr[0].wr.ud.ah = comm->mcast.ah_list[mcast_group_index]; + swr[0].wr.ud.ah = comm->mcast.groups[mcast_group_index].ah; tl_trace(comm->lib, "mcast allgather post_send, psn %d, length %d, " "zcopy %d, signaled %d qp->state %d qp->qp_num %d qp->pd %p " "coll_type %d mcast_group_index %d", pp->psn, pp->length, zcopy, swr[0].send_flags & IBV_SEND_SIGNALED, - comm->mcast.qp_list[mcast_group_index]->state, - comm->mcast.qp_list[mcast_group_index]->qp_num, - comm->mcast.qp_list[mcast_group_index]->pd, coll_type, + comm->mcast.groups[mcast_group_index].qp->state, + comm->mcast.groups[mcast_group_index].qp->qp_num, + comm->mcast.groups[mcast_group_index].qp->pd, coll_type, mcast_group_index); - if (0 != (rc = ibv_post_send(comm->mcast.qp_list[mcast_group_index], &swr[0], &bad_wr))) { + if (0 != (rc = ibv_post_send(comm->mcast.groups[mcast_group_index].qp, &swr[0], &bad_wr))) { tl_error(comm->lib, "post send failed: ret %d, start_psn %d, to_send %d, " "to_recv %d, length %d, psn %d, inline %d", rc, req->start_psn, req->to_send, req->to_recv, @@ -398,7 +398,7 @@ static inline int ucc_tl_mlx5_mcast_recv_collective(ucc_tl_mlx5_mcast_coll_comm_ while (num_left > recv_progressed) { memset(wc, 0, sizeof(sizeof(struct ibv_wc) * POLL_PACKED)); - num_comp = ibv_poll_cq(comm->rcq, POLL_PACKED, &wc[0]); + num_comp = ibv_poll_cq(comm->mcast.rcq, POLL_PACKED, &wc[0]); if (num_comp < 0) { tl_error(comm->lib, "recv queue poll completion failed %d", num_comp); @@ -460,10 +460,9 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_poll_recv(ucc_tl_mlx5_mcast_coll_co uint32_t psn; do { - num_comp = ibv_poll_cq(comm->rcq, 1, &wc); + num_comp = ibv_poll_cq(comm->mcast.rcq, 1, &wc); if (num_comp > 0) { - if (IBV_WC_SUCCESS != wc.status) { tl_error(comm->lib, "mcast_poll_recv: %s err %d num_comp", ibv_wc_status_str(wc.status), num_comp); diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_reliability.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_reliability.c index a3d16fd6d8..3db2d1a8f7 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_reliability.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_reliability.c @@ -119,11 +119,11 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_one_sided_cleanup(ucc_tl_mlx5_mcast comm->mcast.rc_qp = NULL; } - if (comm->srq != NULL && ibv_destroy_srq(comm->srq)) { + if (comm->mcast.srq != NULL && ibv_destroy_srq(comm->mcast.srq)) { tl_error(comm->lib, "ibv_destroy_srq failed"); return UCC_ERR_NO_RESOURCE; } - comm->srq = NULL; + comm->mcast.srq = NULL; if (comm->one_sided.slots_mr) { ibv_dereg_mr(comm->one_sided.slots_mr); diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c index b70ca6e2f6..84efb5daf1 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c @@ -114,8 +114,8 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context, goto cleanup; } - comm->rcq = ibv_create_cq(mcast_context->ctx, comm->params.rx_depth, NULL, NULL, 0); - if (!comm->rcq) { + comm->mcast.rcq = ibv_create_cq(mcast_context->ctx, comm->params.rx_depth, NULL, NULL, 0); + if (!comm->mcast.rcq) { ibv_dereg_mr(comm->grh_mr); 
tl_error(mcast_context->lib, "could not create recv cq, rx_depth %d, errno %d", comm->params.rx_depth, errno); @@ -123,10 +123,10 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context, goto cleanup; } - comm->scq = ibv_create_cq(mcast_context->ctx, comm->params.sx_depth, NULL, NULL, 0); - if (!comm->scq) { + comm->mcast.scq = ibv_create_cq(mcast_context->ctx, comm->params.sx_depth, NULL, NULL, 0); + if (!comm->mcast.scq) { ibv_dereg_mr(comm->grh_mr); - ibv_destroy_cq(comm->rcq); + ibv_destroy_cq(comm->mcast.rcq); tl_error(mcast_context->lib, "could not create send cq, sx_depth %d, errno %d", comm->params.sx_depth, errno); status = UCC_ERR_NO_RESOURCE; @@ -263,7 +263,7 @@ ucc_status_t ucc_tl_mlx5_mcast_coll_setup_comm_resources(ucc_tl_mlx5_mcast_coll_ ucc_list_add_tail(&comm->bpool, &comm->pp[i].super); } - comm->mcast.swr.wr.ud.ah = comm->mcast.ah; + comm->mcast.swr.wr.ud.ah = comm->mcast.groups[0].ah; comm->mcast.swr.num_sge = 1; comm->mcast.swr.sg_list = &comm->mcast.ssg; comm->mcast.swr.opcode = IBV_WR_SEND_WITH_IMM; @@ -325,8 +325,8 @@ ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team) return UCC_INPROGRESS; } - comm->mcast_addr = net_addr; - tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_POST; + comm->mcast.groups[0].mcast_addr = net_addr; + tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_POST; return UCC_INPROGRESS; } @@ -373,11 +373,11 @@ ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team) if (tl_team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_READY) { /* rank 0 bcast the lid/gid to other processes */ - data->status = UCC_OK; - data->dgid = comm->event->param.ud.ah_attr.grh.dgid; - data->dlid = comm->event->param.ud.ah_attr.dlid; - comm->mcast_lid = data->dlid; - comm->mgid = data->dgid; + data->status = UCC_OK; + data->dgid = comm->event->param.ud.ah_attr.grh.dgid; + data->dlid = comm->event->param.ud.ah_attr.dlid; + comm->mcast.groups[0].lid = data->dlid; + comm->mcast.groups[0].mgid = data->dgid; } else { /* rank 0 bcast the failed status to other processes so others do not hang */ data->status = UCC_ERR_NO_RESOURCE; @@ -522,8 +522,8 @@ ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team) return status; } - comm->mcast_addr = net_addr; - tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_POST; + comm->mcast.groups[0].mcast_addr = net_addr; + tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_POST; return UCC_INPROGRESS; } @@ -549,8 +549,8 @@ ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team) ucc_assert(comm->event != NULL); - comm->mcast_lid = comm->group_setup_info->dlid; - comm->mgid = comm->group_setup_info->dgid; + comm->mcast.groups[0].lid = comm->group_setup_info->dlid; + comm->mcast.groups[0].mgid = comm->group_setup_info->dgid; ucc_free(comm->group_setup_info); if (comm->event) { diff --git a/src/components/tl/mlx5/tl_mlx5_team.c b/src/components/tl/mlx5/tl_mlx5_team.c index 6a65274d1d..e5cc29490a 100644 --- a/src/components/tl/mlx5/tl_mlx5_team.c +++ b/src/components/tl/mlx5/tl_mlx5_team.c @@ -198,7 +198,7 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team) tl_warn(UCC_TL_TEAM_LIB(tl_team), "ibv_dereg_mr failed"); } - if (ibv_destroy_cq(comm->rcq)) { + if (ibv_destroy_cq(comm->mcast.rcq)) { tl_warn(UCC_TL_TEAM_LIB(tl_team), "ibv_destroy_cq failed"); } From 08e76398a0d7acf60a3509c331b2bc7e33f9d6c3 Mon Sep 17 00:00:00 2001 From: Ilya Kryukov Date: Tue, 11 Feb 2025 15:20:57 +0100 Subject: [PATCH 08/11] TL/CUDA: Linear Broadcast for GPU (#948) Adding 
a linear CUDA Broadcast implementation with Active set feature support. It provides a functional improvement and parity with other communication libraries. - Ability to place many ranks on a single GPU - No GPU blocking; communication is initiated from the host - Active set can be used to emulate P2P send/receive on top of the broadcast collective --- src/components/tl/cuda/Makefile.am | 8 +- src/components/tl/cuda/allgather/allgather.c | 4 +- .../tl/cuda/allgather/allgather_linear.c | 4 +- .../tl/cuda/allgatherv/allgatherv.c | 4 +- .../tl/cuda/allgatherv/allgatherv_linear.c | 20 +- src/components/tl/cuda/bcast/bcast.c | 28 ++ src/components/tl/cuda/bcast/bcast.h | 43 ++ src/components/tl/cuda/bcast/bcast_linear.c | 419 ++++++++++++++++++ .../tl/cuda/reduce_scatter/reduce_scatter.c | 4 +- .../reduce_scatter/reduce_scatter_linear.c | 4 +- .../tl/cuda/reduce_scatterv/reduce_scatterv.c | 4 +- .../reduce_scatterv/reduce_scatterv_linear.c | 20 +- src/components/tl/cuda/tl_cuda.c | 5 +- src/components/tl/cuda/tl_cuda.h | 38 +- src/components/tl/cuda/tl_cuda_coll.c | 31 +- src/components/tl/cuda/tl_cuda_coll.h | 88 +++- src/components/tl/cuda/tl_cuda_ring.h | 18 +- src/components/tl/cuda/tl_cuda_team.c | 31 +- src/components/tl/cuda/tl_cuda_team_topo.h | 4 +- src/components/tl/ucp/bcast/bcast_knomial.c | 4 +- test/gtest/coll/test_bcast.cc | 10 +- 21 files changed, 681 insertions(+), 110 deletions(-) create mode 100644 src/components/tl/cuda/bcast/bcast.c create mode 100644 src/components/tl/cuda/bcast/bcast.h create mode 100644 src/components/tl/cuda/bcast/bcast_linear.c diff --git a/src/components/tl/cuda/Makefile.am b/src/components/tl/cuda/Makefile.am index e22796e6fa..65fb41ca1f 100644 --- a/src/components/tl/cuda/Makefile.am +++ b/src/components/tl/cuda/Makefile.am @@ -1,5 +1,5 @@ # -# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) Meta Platforms, Inc. and affiliates. 2022. # @@ -27,6 +27,11 @@ alltoallv = \ alltoallv/alltoallv.c \ alltoallv/alltoallv_ce.c +bcast = \ + bcast/bcast.h \ + bcast/bcast.c \ + bcast/bcast_linear.c + reduce_scatter = \ reduce_scatter/reduce_scatter.h \ reduce_scatter/reduce_scatter.c \ @@ -54,6 +59,7 @@ sources = \ $(allgatherv) \ $(alltoall) \ $(alltoallv) \ + $(bcast) \ $(reduce_scatter) \ $(reduce_scatterv) diff --git a/src/components/tl/cuda/allgather/allgather.c b/src/components/tl/cuda/allgather/allgather.c index 01996da4da..362191b3ac 100644 --- a/src/components/tl/cuda/allgather/allgather.c +++ b/src/components/tl/cuda/allgather/allgather.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms.
*/ @@ -44,7 +44,7 @@ ucc_status_t ucc_tl_cuda_allgather_init(ucc_base_coll_args_t *coll_args, { ucc_tl_cuda_team_t *team = ucc_derived_of(tl_team, ucc_tl_cuda_team_t); - if (ucc_tl_cuda_team_topo_is_fully_conntected(team->topo)) { + if (ucc_tl_cuda_team_topo_is_fully_connected(team->topo)) { return ucc_tl_cuda_allgather_linear_init(coll_args, tl_team, task_p); } else { return ucc_tl_cuda_allgather_ring_init(coll_args, tl_team, task_p); diff --git a/src/components/tl/cuda/allgather/allgather_linear.c b/src/components/tl/cuda/allgather/allgather_linear.c index ed228d1683..d0b416257b 100644 --- a/src/components/tl/cuda/allgather/allgather_linear.c +++ b/src/components/tl/cuda/allgather/allgather_linear.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -15,7 +15,7 @@ ucc_status_t ucc_tl_cuda_allgather_linear_init(ucc_base_coll_args_t *coll_args, ucc_tl_cuda_task_t *task; ucc_status_t status; - if (ucc_unlikely(!ucc_tl_cuda_team_topo_is_fully_conntected(team->topo) || + if (ucc_unlikely(!ucc_tl_cuda_team_topo_is_fully_connected(team->topo) || UCC_TL_TEAM_SIZE(team) - 1 > UCC_EE_EXECUTOR_MULTI_OP_NUM_BUFS)) { return UCC_ERR_NOT_SUPPORTED; } diff --git a/src/components/tl/cuda/allgatherv/allgatherv.c b/src/components/tl/cuda/allgatherv/allgatherv.c index 5a8f78c481..4a73bbdf08 100644 --- a/src/components/tl/cuda/allgatherv/allgatherv.c +++ b/src/components/tl/cuda/allgatherv/allgatherv.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -47,7 +47,7 @@ ucc_status_t ucc_tl_cuda_allgatherv_init(ucc_base_coll_args_t *coll_args, { ucc_tl_cuda_team_t *team = ucc_derived_of(tl_team, ucc_tl_cuda_team_t); - if (ucc_tl_cuda_team_topo_is_fully_conntected(team->topo)) { + if (ucc_tl_cuda_team_topo_is_fully_connected(team->topo)) { return ucc_tl_cuda_allgatherv_linear_init(coll_args, tl_team, task_p); } else { return ucc_tl_cuda_allgatherv_ring_init(coll_args, tl_team, task_p); diff --git a/src/components/tl/cuda/allgatherv/allgatherv_linear.c b/src/components/tl/cuda/allgatherv/allgatherv_linear.c index 0fca5c6af6..9a8b5db140 100644 --- a/src/components/tl/cuda/allgatherv/allgatherv_linear.c +++ b/src/components/tl/cuda/allgatherv/allgatherv_linear.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ @@ -55,22 +55,6 @@ enum * other ranks to finish */ }; -static inline int get_rank_step(ucc_tl_cuda_task_t *task, ucc_rank_t rank, - int step_id) -{ - ucc_tl_cuda_sync_t *sync = TASK_SYNC(task, rank); - - return sync->seq_num[step_id]; -} - -static inline void set_rank_step(ucc_tl_cuda_task_t *task, ucc_rank_t rank, - int step, int step_id) -{ - ucc_tl_cuda_sync_t *sync = TASK_SYNC(task, rank); - - sync->seq_num[step_id] = step; -} - ucc_status_t ucc_tl_cuda_allgatherv_linear_finalize(ucc_coll_task_t *coll_task) { ucc_tl_cuda_task_t *task = ucc_derived_of(coll_task, ucc_tl_cuda_task_t); @@ -432,7 +416,7 @@ ucc_status_t ucc_tl_cuda_allgatherv_linear_init(ucc_base_coll_args_t *coll_args, ucc_tl_cuda_task_t *task; ucc_status_t status; - if (ucc_unlikely(!ucc_tl_cuda_team_topo_is_fully_conntected(team->topo) || + if (ucc_unlikely(!ucc_tl_cuda_team_topo_is_fully_connected(team->topo) || UCC_TL_TEAM_SIZE(team) - 1 > UCC_EE_EXECUTOR_MULTI_OP_NUM_BUFS)) { return UCC_ERR_NOT_SUPPORTED; } diff --git a/src/components/tl/cuda/bcast/bcast.c b/src/components/tl/cuda/bcast/bcast.c new file mode 100644 index 0000000000..954cf86d9f --- /dev/null +++ b/src/components/tl/cuda/bcast/bcast.c @@ -0,0 +1,28 @@ +/** + * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. + */ + +#include "bcast.h" +#include "components/mc/ucc_mc.h" + +ucc_base_coll_alg_info_t + ucc_tl_cuda_bcast_algs[UCC_TL_CUDA_BCAST_ALG_LAST + 1] = { + [UCC_TL_CUDA_BCAST_ALG_LINEAR] = {.id = UCC_TL_CUDA_BCAST_ALG_LINEAR, + .name = "linear", + .desc = "linear bcast algorithm"}, + [UCC_TL_CUDA_BCAST_ALG_LAST] = {.id = 0, .name = NULL, .desc = NULL}}; + +ucc_status_t ucc_tl_cuda_bcast_init(ucc_base_coll_args_t *coll_args, + ucc_base_team_t *tl_team, + ucc_coll_task_t **task_p) +{ + ucc_tl_cuda_team_t *team = ucc_derived_of(tl_team, ucc_tl_cuda_team_t); + + if (ucc_tl_cuda_team_topo_is_fully_connected(team->topo)) { + return ucc_tl_cuda_bcast_linear_init(coll_args, tl_team, task_p); + } else { + return UCC_ERR_NOT_SUPPORTED; + } +} diff --git a/src/components/tl/cuda/bcast/bcast.h b/src/components/tl/cuda/bcast/bcast.h new file mode 100644 index 0000000000..5810bcc89d --- /dev/null +++ b/src/components/tl/cuda/bcast/bcast.h @@ -0,0 +1,43 @@ +/** + * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. + */ + +#ifndef BCAST_H_ +#define BCAST_H_ + +#include "tl_cuda.h" +#include "tl_cuda_coll.h" + +enum +{ + UCC_TL_CUDA_BCAST_ALG_LINEAR, + UCC_TL_CUDA_BCAST_ALG_LAST +}; + +extern ucc_base_coll_alg_info_t + ucc_tl_cuda_bcast_algs[UCC_TL_CUDA_BCAST_ALG_LAST + 1]; + +#define UCC_TL_CUDA_BCAST_DEFAULT_ALG_SELECT_STR "bcast:cuda:@0" + +ucc_status_t ucc_tl_cuda_bcast_init(ucc_base_coll_args_t *coll_args, + ucc_base_team_t *tl_team, + ucc_coll_task_t **task_p); + +ucc_status_t ucc_tl_cuda_bcast_linear_init(ucc_base_coll_args_t *coll_args, + ucc_base_team_t *tl_team, + ucc_coll_task_t **task_p); + +static inline int ucc_tl_cuda_bcast_alg_from_str(const char *str) +{ + int i; + for (i = 0; i < UCC_TL_CUDA_BCAST_ALG_LAST; i++) { + if (0 == strcasecmp(str, ucc_tl_cuda_bcast_algs[i].name)) { + break; + } + } + return i; +} + +#endif diff --git a/src/components/tl/cuda/bcast/bcast_linear.c b/src/components/tl/cuda/bcast/bcast_linear.c new file mode 100644 index 0000000000..f865cd6711 --- /dev/null +++ b/src/components/tl/cuda/bcast/bcast_linear.c @@ -0,0 +1,419 @@ +/** + * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * + * See file LICENSE for terms. + */ + +#include "bcast.h" + +enum { + // Barrier setup stages + STAGE_INIT_BAR_ROOT, // Initial stage for the root rank to identify and claim a free barrier + STAGE_FIND_BAR_PEER, // Stage where peer ranks wait while the root rank identifies a free barrier + + STAGE_SYNC, // Initialize the barrier and synchronize the segment required for the current task + STAGE_SETUP, // Verify that all ranks are aligned and have reached the barrier + // Stages specific to the root rank + STAGE_COPY, // Post copy task: copy data block from src to a scratch buffer + STAGE_WAIT_COPY, // The root waits for the completion of its copy operation + STAGE_WAIT_ALL, // The root rank waits until all other ranks have reached the same operational step + STAGE_WAIT_COMPLETION, // The root rank waits for all other ranks to complete the broadcast operation + // non-root + STAGE_WAIT_ROOT, // Wait while the root rank writes data to its scratch buffer + STAGE_CLIENT_COPY, // Initiate their own copy tasks after the root's operations + STAGE_CLIENT_COPY_WAIT, // Wait for the completion of the copy operation from the root's scratch buffer + STAGE_CLIENT_WAIT_COMPLETION, // Wait for the completion of algorithm on all ranks, global sync with root +}; + +static inline ucc_status_t +ucc_tl_cuda_bcast_linear_setup_start(ucc_tl_cuda_task_t *task) +{ + ucc_tl_cuda_team_t *team = TASK_TEAM(task); + ucc_rank_t trank = UCC_TL_TEAM_RANK(team); + + set_rank_step(task, trank, 0, 0); // Initialize rank step tracking + ucc_memory_cpu_store_fence(); + // initiate barrier wait while all ranks set theirs steps to 0 + return ucc_tl_cuda_shm_barrier_start(trank, task->bar); +} + +// Tests if setup is complete for a linear broadcast task +static inline ucc_status_t +ucc_tl_cuda_bcast_linear_setup_test(ucc_tl_cuda_task_t *task) +{ + ucc_tl_cuda_team_t *team = TASK_TEAM(task); + return ucc_tl_cuda_shm_barrier_test(UCC_TL_TEAM_RANK(team), task->bar); +} + +// Returns the size of the scratch buffer used for data transfers +static inline size_t get_raw_scratch_size(ucc_tl_cuda_team_t *team) +{ + return UCC_TL_CUDA_TEAM_LIB(team)->cfg.scratch_size; +} + +// Posts a copy task to the CUDA executor +static inline ucc_status_t ecopy(void *dst, void *src, size_t size, + ucc_ee_executor_t *exec, + ucc_ee_executor_task_t **etask) +{ + ucc_ee_executor_task_args_t exec_args = {0}; + + exec_args.task_type = UCC_EE_EXECUTOR_TASK_COPY; + exec_args.copy.dst = dst; + exec_args.copy.src = src; + exec_args.copy.len = size; + return ucc_ee_executor_task_post(exec, &exec_args, etask); +} + +// Root rank searches for and claims a free barrier +static inline ucc_status_t root_find_free_barrier(ucc_tl_cuda_task_t *task) +{ + ucc_tl_cuda_team_t *team = TASK_TEAM(task); + uint32_t max_concurrent = UCC_TL_CUDA_TEAM_LIB(team)->cfg.max_concurrent; + ucc_tl_cuda_shm_barrier_t *curr_bar; + int i; + ucc_status_t st; + + // Iterate over available barriers in active set pool to find a free one + for (i = 0; i < max_concurrent; ++i) { + curr_bar = UCC_TL_CUDA_TEAM_BARRIER(team, max_concurrent + i); + // try to set user specified tag to mark that this barrier is used by this task + if (ucc_atomic_cswap64(&curr_bar->tag, UCC_TL_CUDA_TAG_FREE, + task->bcast_linear.key) == UCC_TL_CUDA_TAG_FREE) { + ucc_debug("Acquire barrier: %p idx: %d marked with tag: %ld", + curr_bar, i, curr_bar->tag); + task->bar = curr_bar; + st = ucc_tl_cuda_shm_barrier_init_root( + task->subset.map.ep_num, task->subset.myrank, + task->bcast_linear.root, 
task->bar); + if (ucc_unlikely(st != UCC_OK)) { + ucc_error("failed to init root barrier"); + return UCC_ERR_NO_RESOURCE; + } + // Assign a collective ID (index of barrier) + task->coll_id = i + max_concurrent; + return UCC_OK; + } + } + // try next time + return UCC_ERR_NOT_FOUND; +} + +// Peer rank searches for a barrier claimed by the root +static inline ucc_status_t peer_find_free_barrier(ucc_tl_cuda_task_t *task) +{ + ucc_tl_cuda_team_t *team = TASK_TEAM(task); + uint32_t max_concurrent = UCC_TL_CUDA_TEAM_LIB(team)->cfg.max_concurrent; + ucc_tl_cuda_shm_barrier_t *curr_bar; + int i; + ucc_status_t st; + + for (i = 0; i < max_concurrent; ++i) { + curr_bar = UCC_TL_CUDA_TEAM_BARRIER(team, max_concurrent + i); + // Check if the barrier is claimed by the task's root + if (curr_bar->tag == task->bcast_linear.key) { + task->bar = curr_bar; + st = ucc_tl_cuda_shm_barrier_init_root( + task->subset.map.ep_num, task->subset.myrank, + task->bcast_linear.root, task->bar); + if (ucc_unlikely(st != UCC_OK)) { + ucc_error("failed to init peer barrier"); + return UCC_ERR_NO_RESOURCE; + } + task->coll_id = i + max_concurrent; + return UCC_OK; + } + } + // try next time + return UCC_ERR_NOT_FOUND; +} + +static ucc_status_t +ucc_tl_cuda_bcast_linear_finalize(ucc_coll_task_t *coll_task) +{ + ucc_tl_cuda_task_t *task = ucc_derived_of(coll_task, ucc_tl_cuda_task_t); + + tl_trace(UCC_TASK_LIB(task), "finalizing task %p", task); + ucc_tl_cuda_task_put(task); + return UCC_OK; +} + +static void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task) +{ + ucc_tl_cuda_task_t *task = ucc_derived_of(coll_task, ucc_tl_cuda_task_t); + ucc_tl_cuda_team_t *team = TASK_TEAM(task); + ucc_rank_t trank = UCC_TL_TEAM_RANK(team); + size_t half_scratch_size = get_raw_scratch_size(team) / 2; + ucc_rank_t tsize = UCC_COLL_ARGS_ACTIVE_SET(&TASK_ARGS(task)) + ? 
(ucc_rank_t)task->subset.map.ep_num + : UCC_TL_TEAM_SIZE(team); + size_t chunk_size = ucc_min( + half_scratch_size, + task->bcast_linear.size - task->bcast_linear.step * half_scratch_size); + size_t offset_buff = task->bcast_linear.step * half_scratch_size; + ucc_ee_executor_t *exec; + ucc_ee_executor_task_t *etask; + void *sbuf, *dbuf; + ucc_rank_t peer; + ucc_status_t st; + int i; + + task->super.status = UCC_INPROGRESS; + + st = ucc_coll_task_get_executor(&task->super, &exec); + if (ucc_unlikely(st != UCC_OK)) { + task->super.status = st; + return; + } + + switch (task->bcast_linear.stage) { + case STAGE_INIT_BAR_ROOT: + st = root_find_free_barrier(task); + if (st == UCC_OK) { + task->bcast_linear.stage = STAGE_SYNC; + } else if (st != UCC_ERR_NOT_FOUND) { + task->super.status = st; + } + // no free barriers found, try next time + return; + case STAGE_FIND_BAR_PEER: + st = peer_find_free_barrier(task); + if (st == UCC_OK) { + // barrier found, continue to next stages + task->bcast_linear.stage = STAGE_SYNC; + } else if (st != UCC_ERR_NOT_FOUND) { + task->super.status = st; + } + // no free barriers found by root, try next time + return; + case STAGE_SYNC: + if (ucc_tl_cuda_get_sync_root(task, task->bcast_linear.root) != UCC_OK) { + return; + } + task->bcast_linear.step = 0; + st = ucc_tl_cuda_bcast_linear_setup_start(task); + if (st != UCC_OK) { + task->super.status = st; + return; + } + task->bcast_linear.stage = STAGE_SETUP; + /* fall through */ + case STAGE_SETUP: + st = ucc_tl_cuda_bcast_linear_setup_test(task); + if (st != UCC_OK) { + task->super.status = st; + return; + } + if (trank == task->bcast_linear.root) { + task->bcast_linear.stage = STAGE_COPY; + } else { + task->bcast_linear.stage = STAGE_WAIT_ROOT; + } + /* fall through */ + default: + break; + } + + if (trank == task->bcast_linear.root) { + // Root scenario + // fall-through between cases is intentional + switch (task->bcast_linear.stage) { + case STAGE_COPY: + // copy from src buffer to scratch + dbuf = PTR_OFFSET(TASK_SCRATCH(task, trank), + task->bcast_linear.step % 2 * half_scratch_size); + sbuf = PTR_OFFSET(task->bcast_linear.sbuf, offset_buff); + st = ecopy(dbuf, sbuf, chunk_size, exec, + &task->bcast_linear.exec_task); + if (st != UCC_OK) { + ucc_error("failed to post ecopy task"); + task->super.status = st; + return; + } + task->bcast_linear.stage = STAGE_WAIT_COPY; + /* fall through */ + case STAGE_WAIT_COPY: + etask = task->bcast_linear.exec_task; + ucc_assert(NULL != etask); + st = ucc_ee_executor_task_test(etask); + if (st != UCC_OK) { + return; // not ready + } + ucc_ee_executor_task_finalize(etask); + task->bcast_linear.exec_task = NULL; + // signal others + ++task->bcast_linear.step; + set_rank_step(task, task->bcast_linear.root, + task->bcast_linear.step, 0); + task->bcast_linear.stage = STAGE_WAIT_ALL; + /* fall through */ + case STAGE_WAIT_ALL: + for (i = 0; i < tsize; ++i) { + if (UCC_COLL_ARGS_ACTIVE_SET(&TASK_ARGS(task))) { + // eval phys rank from virt + peer = ucc_ep_map_eval(task->subset.map, i); + } else { + peer = i; + } + // need to wait until all ranks complete step - 1, because of double buffering + if (get_rank_step(task, peer, 0) < + task->bcast_linear.step - 1) { + // rank is not ready, lets wait + return; + } + } + if (task->bcast_linear.step < task->bcast_linear.num_steps) { + // go to next iteration + task->bcast_linear.stage = STAGE_COPY; + return; + } + // finish + st = ucc_tl_cuda_shm_barrier_start(trank, task->bar); + if (ucc_unlikely(st != UCC_OK)) { + ucc_error("failed to start 
barrier from root rank"); + task->super.status = st; + return; + } + task->bcast_linear.stage = STAGE_WAIT_COMPLETION; + /* fall through */ + case STAGE_WAIT_COMPLETION: + st = ucc_tl_cuda_shm_barrier_test(trank, task->bar); + if (st != UCC_OK) { + // peers still working, lets check next time + task->super.status = st; + return; + } + // set barrier free to unlock others, this is roots responsibility + ucc_debug("Release bar: %p with tag: %ld", task->bar, + task->bar->tag); + task->bar->tag = UCC_TL_CUDA_TAG_FREE; + ucc_tl_cuda_put_sync_root(task, task->bcast_linear.root); + task->super.status = UCC_OK; + break; + default: + ucc_assert(0); + break; + } + } else { + // clients + // fall-through between cases is intentional + switch (task->bcast_linear.stage) { + case STAGE_WAIT_ROOT: + if (get_rank_step(task, task->bcast_linear.root, 0) > + task->bcast_linear.step) { + task->bcast_linear.stage = STAGE_CLIENT_COPY; + break; + } else { + return; + } + /* fall through */ + case STAGE_CLIENT_COPY: + // need to copy from root's scratch buffer + dbuf = PTR_OFFSET(task->bcast_linear.sbuf, offset_buff); + sbuf = PTR_OFFSET(TASK_SCRATCH(task, task->bcast_linear.root), + task->bcast_linear.step % 2 * half_scratch_size); + st = ecopy(dbuf, sbuf, chunk_size, exec, + &task->bcast_linear.exec_task); + if (st != UCC_OK) { + ucc_error("failed to post ecopy task at client"); + task->super.status = st; + return; + } + task->bcast_linear.stage = STAGE_CLIENT_COPY_WAIT; + /* fall through */ + case STAGE_CLIENT_COPY_WAIT: + etask = task->bcast_linear.exec_task; + ucc_assert(NULL != etask); + st = ucc_ee_executor_task_test(etask); + if (st != UCC_OK) { + return; // executor task is not ready + } + ucc_ee_executor_task_finalize(etask); + task->bcast_linear.exec_task = NULL; + ++task->bcast_linear.step; + set_rank_step(task, trank, task->bcast_linear.step, 0); + if (task->bcast_linear.step < task->bcast_linear.num_steps) { + task->bcast_linear.stage = STAGE_WAIT_ROOT; + return; + } + // start barrier to sync with root + st = ucc_tl_cuda_shm_barrier_start(trank, task->bar); + if (ucc_unlikely(st != UCC_OK)) { + ucc_error("failed to start barrier from peer rank"); + task->super.status = st; + return; + } + task->bcast_linear.stage = STAGE_CLIENT_WAIT_COMPLETION; + /* fall through */ + case STAGE_CLIENT_WAIT_COMPLETION: + st = ucc_tl_cuda_shm_barrier_test(trank, task->bar); + if (st != UCC_OK) { + // someone still working, lets check next time + task->super.status = st; + return; + } + task->super.status = UCC_OK; + break; + default: + ucc_assert(0); + break; + } + } +} + +static ucc_status_t ucc_tl_cuda_bcast_linear_start(ucc_coll_task_t *coll_task) +{ + ucc_tl_cuda_task_t *task = ucc_derived_of(coll_task, ucc_tl_cuda_task_t); + ucc_tl_cuda_team_t *team = TASK_TEAM(task); + + task->bcast_linear.step = 0; + task->bcast_linear.stage = STAGE_SYNC; + // in case of active set bcast we need to do additional steps to find free barriers + if (UCC_COLL_ARGS_ACTIVE_SET(&TASK_ARGS(task))) { + task->bcast_linear.stage = + UCC_TL_TEAM_RANK(team) == task->bcast_linear.root + ? 
STAGE_INIT_BAR_ROOT + : STAGE_FIND_BAR_PEER; + } + + ucc_debug("bcast linear start dt: %s, buffer size: %ld, num_steps: %d", + ucc_datatype_str(task->bcast_linear.dt), task->bcast_linear.size, + task->bcast_linear.num_steps); + + return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); +} + +ucc_status_t ucc_tl_cuda_bcast_linear_init(ucc_base_coll_args_t *coll_args, + ucc_base_team_t *tl_team, + ucc_coll_task_t **task_p) +{ + ucc_tl_cuda_team_t *team = ucc_derived_of(tl_team, ucc_tl_cuda_team_t); + ucc_tl_cuda_task_t *task; + ucc_status_t status; + + if (!ucc_tl_cuda_team_topo_is_fully_connected(team->topo) || + UCC_TL_TEAM_SIZE(team) - 1 > UCC_EE_EXECUTOR_MULTI_OP_NUM_BUFS) { + return UCC_ERR_NOT_SUPPORTED; + } + + status = ucc_tl_cuda_task_init(coll_args, team, &task); + if (ucc_unlikely(status != UCC_OK)) { + return status; + } + + task->bcast_linear.root = coll_args->args.root; + task->bcast_linear.dt = coll_args->args.src.info.datatype; + task->bcast_linear.sbuf = coll_args->args.src.info.buffer; + task->bcast_linear.size = + ucc_dt_size(task->bcast_linear.dt) * coll_args->args.src.info.count; + task->bcast_linear.num_steps = ucc_div_round_up( + task->bcast_linear.size, get_raw_scratch_size(team) / 2); + + task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR; + task->super.post = ucc_tl_cuda_bcast_linear_start; + task->super.progress = ucc_tl_cuda_bcast_linear_progress; + task->super.finalize = ucc_tl_cuda_bcast_linear_finalize; + + *task_p = &task->super; + return UCC_OK; +} diff --git a/src/components/tl/cuda/reduce_scatter/reduce_scatter.c b/src/components/tl/cuda/reduce_scatter/reduce_scatter.c index 468fd68338..237005c95b 100644 --- a/src/components/tl/cuda/reduce_scatter/reduce_scatter.c +++ b/src/components/tl/cuda/reduce_scatter/reduce_scatter.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -48,7 +48,7 @@ ucc_status_t ucc_tl_cuda_reduce_scatter_init(ucc_base_coll_args_t *coll_args, { ucc_tl_cuda_team_t *team = ucc_derived_of(tl_team, ucc_tl_cuda_team_t); - if (ucc_tl_cuda_team_topo_is_fully_conntected(team->topo)) { + if (ucc_tl_cuda_team_topo_is_fully_connected(team->topo)) { return ucc_tl_cuda_reduce_scatter_linear_init(coll_args, tl_team, task_p); } else { diff --git a/src/components/tl/cuda/reduce_scatter/reduce_scatter_linear.c b/src/components/tl/cuda/reduce_scatter/reduce_scatter_linear.c index 46efbdb051..36801ce1d8 100644 --- a/src/components/tl/cuda/reduce_scatter/reduce_scatter_linear.c +++ b/src/components/tl/cuda/reduce_scatter/reduce_scatter_linear.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ @@ -19,7 +19,7 @@ ucc_status_t ucc_tl_cuda_reduce_scatter_linear_init(ucc_base_coll_args_t *coll_a return UCC_ERR_NOT_SUPPORTED; } - if (ucc_unlikely(!ucc_tl_cuda_team_topo_is_fully_conntected(team->topo) || + if (ucc_unlikely(!ucc_tl_cuda_team_topo_is_fully_connected(team->topo) || UCC_TL_TEAM_SIZE(team) - 1 > UCC_EE_EXECUTOR_MULTI_OP_NUM_BUFS)) { return UCC_ERR_NOT_SUPPORTED; } diff --git a/src/components/tl/cuda/reduce_scatterv/reduce_scatterv.c b/src/components/tl/cuda/reduce_scatterv/reduce_scatterv.c index d85e2c8dd3..eef433cdbb 100644 --- a/src/components/tl/cuda/reduce_scatterv/reduce_scatterv.c +++ b/src/components/tl/cuda/reduce_scatterv/reduce_scatterv.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -51,7 +51,7 @@ ucc_status_t ucc_tl_cuda_reduce_scatterv_init(ucc_base_coll_args_t *coll_args, { ucc_tl_cuda_team_t *team = ucc_derived_of(tl_team, ucc_tl_cuda_team_t); - if (ucc_tl_cuda_team_topo_is_fully_conntected(team->topo)) { + if (ucc_tl_cuda_team_topo_is_fully_connected(team->topo)) { return ucc_tl_cuda_reduce_scatterv_linear_init(coll_args, tl_team, task_p); } else { diff --git a/src/components/tl/cuda/reduce_scatterv/reduce_scatterv_linear.c b/src/components/tl/cuda/reduce_scatterv/reduce_scatterv_linear.c index 6a1ec5b22c..56e4e2204c 100644 --- a/src/components/tl/cuda/reduce_scatterv/reduce_scatterv_linear.c +++ b/src/components/tl/cuda/reduce_scatterv/reduce_scatterv_linear.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -59,22 +59,6 @@ enum * other ranks to finish */ }; -static inline int get_rank_step(ucc_tl_cuda_task_t *task, ucc_rank_t rank, - int step_id) -{ - ucc_tl_cuda_sync_t *sync = TASK_SYNC(task, rank); - - return sync->seq_num[step_id]; -} - -static inline void set_rank_step(ucc_tl_cuda_task_t *task, ucc_rank_t rank, - int step, int step_id) -{ - ucc_tl_cuda_sync_t *sync = TASK_SYNC(task, rank); - - sync->seq_num[step_id] = step; -} - ucc_status_t ucc_tl_cuda_reduce_scatterv_linear_finalize(ucc_coll_task_t *coll_task) { @@ -448,7 +432,7 @@ ucc_tl_cuda_reduce_scatterv_linear_init(ucc_base_coll_args_t *coll_args, return UCC_ERR_NOT_SUPPORTED; } - if (ucc_unlikely(!ucc_tl_cuda_team_topo_is_fully_conntected(team->topo) || + if (ucc_unlikely(!ucc_tl_cuda_team_topo_is_fully_connected(team->topo) || UCC_TL_TEAM_SIZE(team) - 1 > UCC_EE_EXECUTOR_MULTI_OP_NUM_BUFS)) { return UCC_ERR_NOT_SUPPORTED; } diff --git a/src/components/tl/cuda/tl_cuda.c b/src/components/tl/cuda/tl_cuda.c index 98dccf26bf..842db59c72 100644 --- a/src/components/tl/cuda/tl_cuda.c +++ b/src/components/tl/cuda/tl_cuda.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ @@ -9,6 +9,7 @@ #include "components/mc/base/ucc_mc_base.h" #include "allgather/allgather.h" #include "allgatherv/allgatherv.h" +#include "bcast/bcast.h" #include "reduce_scatter/reduce_scatter.h" #include "reduce_scatterv/reduce_scatterv.h" @@ -93,6 +94,8 @@ __attribute__((constructor)) static void tl_cuda_iface_init(void) ucc_tl_cuda_allgather_algs; ucc_tl_cuda.super.alg_info[ucc_ilog2(UCC_COLL_TYPE_ALLGATHERV)] = ucc_tl_cuda_allgatherv_algs; + ucc_tl_cuda.super.alg_info[ucc_ilog2(UCC_COLL_TYPE_BCAST)] = + ucc_tl_cuda_bcast_algs; ucc_tl_cuda.super.alg_info[ucc_ilog2(UCC_COLL_TYPE_REDUCE_SCATTER)] = ucc_tl_cuda_reduce_scatter_algs; ucc_tl_cuda.super.alg_info[ucc_ilog2(UCC_COLL_TYPE_REDUCE_SCATTERV)] = diff --git a/src/components/tl/cuda/tl_cuda.h b/src/components/tl/cuda/tl_cuda.h index 792100c80c..9742ac8ba2 100644 --- a/src/components/tl/cuda/tl_cuda.h +++ b/src/components/tl/cuda/tl_cuda.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) Meta Platforms, Inc. and affiliates. 2022. * * See file LICENSE for terms. @@ -12,6 +12,7 @@ #include "components/tl/ucc_tl_log.h" #include "components/mc/ucc_mc.h" #include "utils/ucc_mpool.h" +#include "utils/ucc_datastruct.h" #include "tl_cuda_ep_hash.h" #include "tl_cuda_topo.h" #include "tl_cuda_team_topo.h" @@ -27,6 +28,7 @@ #define UCC_TL_CUDA_SUPPORTED_COLLS \ (UCC_COLL_TYPE_ALLTOALL | UCC_COLL_TYPE_ALLTOALLV | \ UCC_COLL_TYPE_ALLGATHER | UCC_COLL_TYPE_ALLGATHERV | \ + UCC_COLL_TYPE_BCAST | \ UCC_COLL_TYPE_REDUCE_SCATTER | UCC_COLL_TYPE_REDUCE_SCATTERV) #define UCC_TL_CUDA_TEAM_LIB(_team) \ @@ -72,7 +74,7 @@ extern ucc_tl_cuda_iface_t ucc_tl_cuda; typedef struct ucc_tl_cuda_lib_config { ucc_tl_lib_config_t super; - uint32_t max_concurrent; + uint32_t max_concurrent; // Maximum number of tasks that can be progressed simultaneously. size_t scratch_size; unsigned long allgather_ring_max_rings; uint32_t allgather_ring_num_chunks; @@ -104,9 +106,12 @@ UCC_CLASS_DECLARE(ucc_tl_cuda_context_t, const ucc_base_context_params_t *, typedef uint32_t ucc_tl_cuda_sync_state_t; +#define UCC_TL_CUDA_TAG_FREE 0xFFFFFFFFFFFFFFFF + typedef struct ucc_tl_cuda_shm_barrier { ucc_rank_t size; ucc_rank_t count; + uint64_t tag; int sense; ucc_status_t state[UCC_TL_CUDA_MAX_PEERS]; int local_sense[UCC_TL_CUDA_MAX_PEERS]; @@ -152,13 +157,15 @@ typedef struct ucc_tl_cuda_scratch { ucc_tl_cuda_mem_info_t rem_info[UCC_TL_CUDA_MAX_PEERS]; } ucc_tl_cuda_scratch_t; +// Team represents a communicator created within the CUDA context, typically using NVLink for inter-GPU communication typedef struct ucc_tl_cuda_team { ucc_tl_team_t super; - uint32_t seq_num; + uint32_t seq_num; // Counter for the number of launched collective tasks for this team + uint32_t seq_num_active_set; // Counter for tasks in the active set (subset of tasks requiring special handling) ucc_tl_cuda_team_topo_t *topo; - ucc_tl_cuda_sync_t *sync; - ucc_tl_cuda_sync_state_t *sync_state; - ucc_tl_cuda_shm_barrier_t *bar; + ucc_tl_cuda_sync_t *sync; // Pointer to shared memory segment for synchronization + ucc_tl_cuda_sync_state_t *sync_state; // Tracks the task currently using the sync segment of shared memory, if free - 0 + ucc_tl_cuda_shm_barrier_t *bar; // Pointer to the first barrier in an array of size [0; 2 * max_concurrent]. 
First max_concurrent barriers are for normal mode, the second one for active set mode ucc_tl_cuda_scratch_t scratch; cudaStream_t stream; ucc_tl_cuda_rank_id_t *ids; @@ -169,12 +176,14 @@ typedef struct ucc_tl_cuda_team { UCC_CLASS_DECLARE(ucc_tl_cuda_team_t, ucc_base_context_t *, const ucc_base_team_params_t *); +// Task represents a collective operation that runs in the CUDA context, typically using NVLink for inter-GPU communication typedef struct ucc_tl_cuda_task ucc_tl_cuda_task_t; struct ucc_tl_cuda_task { ucc_coll_task_t super; - uint32_t seq_num; - uint32_t coll_id; - ucc_tl_cuda_shm_barrier_t *bar; + uint32_t seq_num; // Sequential identifier for each task started within the team + uint32_t coll_id; // Index of the collective task in flight, within the range [0; max_concurrent) + ucc_tl_cuda_shm_barrier_t *bar; // Pointer to the reserved barrier for this task in the CUDA team + ucc_subset_t subset; // Mapping information for the active set, if it is present union { struct { int stage; @@ -224,6 +233,17 @@ struct ucc_tl_cuda_task { size_t (*get_offset)(const ucc_tl_cuda_task_t *task, ucc_rank_t block); } allgatherv_linear; + struct { + int stage; + int step; + void *sbuf; + ucc_datatype_t dt; + ucc_rank_t root; + size_t size; + int num_steps; + ucc_ee_executor_task_t *exec_task; + uint64_t key; // This is mix of user provided tag, root and peer to be unique for each task, algorithm uses it to mark barrier as used + } bcast_linear; struct { int stage; int num_frags; diff --git a/src/components/tl/cuda/tl_cuda_coll.c b/src/components/tl/cuda/tl_cuda_coll.c index 5d01cc1a94..71325dc826 100644 --- a/src/components/tl/cuda/tl_cuda_coll.c +++ b/src/components/tl/cuda/tl_cuda_coll.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ @@ -9,6 +9,7 @@ #include "alltoallv/alltoallv.h" #include "allgather/allgather.h" #include "allgatherv/allgatherv.h" +#include "bcast/bcast.h" #include "reduce_scatter/reduce_scatter.h" #include "reduce_scatterv/reduce_scatterv.h" #include "utils/arch/cpu.h" @@ -35,6 +36,7 @@ const char * ucc_tl_cuda_default_alg_select_str[UCC_TL_CUDA_N_DEFAULT_ALG_SELECT_STR] = { UCC_TL_CUDA_ALLGATHER_DEFAULT_ALG_SELECT_STR, UCC_TL_CUDA_ALLGATHERV_DEFAULT_ALG_SELECT_STR, + UCC_TL_CUDA_BCAST_DEFAULT_ALG_SELECT_STR, UCC_TL_CUDA_REDUCE_SCATTER_DEFAULT_ALG_SELECT_STR, UCC_TL_CUDA_REDUCE_SCATTERV_DEFAULT_ALG_SELECT_STR}; @@ -78,6 +80,8 @@ ucc_status_t ucc_tl_cuda_coll_init(ucc_base_coll_args_t *coll_args, return ucc_tl_cuda_allgather_init(coll_args, team, task_h); case UCC_COLL_TYPE_ALLGATHERV: return ucc_tl_cuda_allgatherv_init(coll_args, team, task_h); + case UCC_COLL_TYPE_BCAST: + return ucc_tl_cuda_bcast_init(coll_args, team, task_h); case UCC_COLL_TYPE_REDUCE_SCATTER: return ucc_tl_cuda_reduce_scatter_init(coll_args, team, task_h); case UCC_COLL_TYPE_REDUCE_SCATTERV: @@ -89,6 +93,19 @@ ucc_status_t ucc_tl_cuda_coll_init(ucc_base_coll_args_t *coll_args, } } +ucc_status_t ucc_tl_cuda_shm_barrier_init_root(ucc_rank_t size, ucc_rank_t rank, ucc_rank_t root, + ucc_tl_cuda_shm_barrier_t *barrier) +{ + if (rank == root) { + barrier->size = size; + barrier->count = 0; + barrier->sense = 0; + } + barrier->state[rank] = UCC_OK; + barrier->local_sense[rank] = 1; + return UCC_OK; +} + ucc_status_t ucc_tl_cuda_shm_barrier_init(ucc_rank_t size, ucc_rank_t rank, ucc_tl_cuda_shm_barrier_t *barrier) { @@ -134,6 +151,8 @@ static inline int alg_id_from_str(ucc_coll_type_t coll_type, const char *str) return ucc_tl_cuda_allgather_alg_from_str(str); case UCC_COLL_TYPE_ALLGATHERV: return ucc_tl_cuda_allgatherv_alg_from_str(str); + case UCC_COLL_TYPE_BCAST: + return ucc_tl_cuda_bcast_alg_from_str(str); default: break; } @@ -187,6 +206,16 @@ ucc_status_t ucc_tl_cuda_alg_id_to_init(int alg_id, const char *alg_id_str, break; }; break; + case UCC_COLL_TYPE_BCAST: + switch (alg_id) { + case UCC_TL_CUDA_BCAST_ALG_LINEAR: + *init = ucc_tl_cuda_bcast_linear_init; + break; + default: + status = UCC_ERR_INVALID_PARAM; + break; + }; + break; case UCC_COLL_TYPE_REDUCE_SCATTER: switch (alg_id) { case UCC_TL_CUDA_REDUCE_SCATTER_ALG_AUTO: diff --git a/src/components/tl/cuda/tl_cuda_coll.h b/src/components/tl/cuda/tl_cuda_coll.h index 8b15cdf249..8ad8fca5f6 100644 --- a/src/components/tl/cuda/tl_cuda_coll.h +++ b/src/components/tl/cuda/tl_cuda_coll.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
 */
@@ -10,7 +10,7 @@
 #include "tl_cuda.h"
 #include "components/mc/ucc_mc.h"
 
-#define UCC_TL_CUDA_N_DEFAULT_ALG_SELECT_STR 4
+#define UCC_TL_CUDA_N_DEFAULT_ALG_SELECT_STR 5
 
 extern const char
     *ucc_tl_cuda_default_alg_select_str[UCC_TL_CUDA_N_DEFAULT_ALG_SELECT_STR];
@@ -50,6 +50,19 @@ static inline void ucc_tl_cuda_task_reset(ucc_tl_cuda_task_t *task)
     task->super.status = UCC_INPROGRESS;
 }
 
+ucc_status_t ucc_tl_cuda_shm_barrier_init_root(ucc_rank_t size, ucc_rank_t rank, ucc_rank_t root,
+                                               ucc_tl_cuda_shm_barrier_t *barrier);
+
+ucc_status_t ucc_tl_cuda_shm_barrier_init(ucc_rank_t size, ucc_rank_t rank,
+                                          ucc_tl_cuda_shm_barrier_t *barrier);
+
+ucc_status_t ucc_tl_cuda_shm_barrier_start(ucc_rank_t rank,
+                                           ucc_tl_cuda_shm_barrier_t *barrier);
+
+ucc_status_t ucc_tl_cuda_shm_barrier_test(ucc_rank_t rank,
+                                          ucc_tl_cuda_shm_barrier_t *barrier);
+
+
 static inline ucc_tl_cuda_task_t *ucc_tl_cuda_task_get(ucc_tl_cuda_team_t *team)
 {
     ucc_tl_cuda_context_t *ctx = UCC_TL_CUDA_TEAM_CTX(team);
@@ -74,6 +87,13 @@ static inline void ucc_tl_cuda_task_put(ucc_tl_cuda_task_t *task)
     ucc_mpool_put(task);
 }
 
+static inline uint64_t compute_key(ucc_rank_t root, ucc_rank_t peer, uint16_t tag)
+{
+    assert(peer < (1 << 24));
+    assert(root < (1 << 24));
+    return (uint64_t)tag << 48 | root << 24 | peer;
+}
+
 static inline ucc_status_t
 ucc_tl_cuda_task_init(ucc_base_coll_args_t *coll_args,
                       ucc_tl_cuda_team_t *team,
@@ -82,6 +102,7 @@ ucc_status_t ucc_tl_cuda_task_init(ucc_base_coll_args_t *coll_args,
     ucc_rank_t          trank = UCC_TL_TEAM_RANK(team);
     ucc_tl_cuda_lib_t  *lib = UCC_TL_CUDA_TEAM_LIB(team);
     uint32_t            max_concurrent = lib->cfg.max_concurrent;
+    ucc_rank_t          peer;
     ucc_tl_cuda_task_t *task;
     ucc_status_t        status;
 
@@ -100,19 +121,34 @@ ucc_status_t ucc_tl_cuda_task_init(ucc_base_coll_args_t *coll_args,
         return status;
     }
 
-    task->seq_num = team->seq_num++;
-    task->coll_id = task->seq_num % max_concurrent;
+    /* active set */
+    if (UCC_COLL_ARGS_ACTIVE_SET(&coll_args->args)) {
+        ucc_assert(coll_args->args.coll_type == UCC_COLL_TYPE_BCAST);
+        task->subset.map    = ucc_active_set_to_ep_map(&coll_args->args);
+        task->subset.myrank = UCC_TL_TEAM_RANK(team);
+        // currently we only support active set bcast with 2 ranks,
+        // so the root remaps subset rank 1 to the peer's physical rank
+        peer = (task->subset.myrank == coll_args->args.root) ?
ucc_ep_map_eval(task->subset.map, 1) : task->subset.myrank; + task->bcast_linear.key = compute_key(coll_args->args.root, peer, coll_args->args.tag); + task->seq_num = team->seq_num_active_set++; + } else { + task->seq_num = team->seq_num++; + task->coll_id = task->seq_num % max_concurrent; + task->bar = TASK_BAR(task); + } *task_h = task; return UCC_OK; } -static inline ucc_status_t ucc_tl_cuda_get_sync(ucc_tl_cuda_task_t *task) +// check if segment for current task is available and barrier is available (completed from prev iteration) +// and possibly mark the segment as occupied by updating the state counter to the current seq_num +static inline ucc_status_t ucc_tl_cuda_get_sync_root(ucc_tl_cuda_task_t *task, ucc_rank_t root) { ucc_tl_cuda_team_t *team = TASK_TEAM(task); volatile ucc_tl_cuda_sync_state_t *state = &team->sync_state[task->coll_id]; - if ((UCC_TL_TEAM_RANK(team) == 0) && (*state == 0)) { + if ((UCC_TL_TEAM_RANK(team) == root) && (*state == 0)) { *state = task->seq_num; } if ((*state != task->seq_num) || @@ -122,17 +158,27 @@ static inline ucc_status_t ucc_tl_cuda_get_sync(ucc_tl_cuda_task_t *task) return UCC_OK; } -static inline void ucc_tl_cuda_put_sync(ucc_tl_cuda_task_t *task) +static inline void ucc_tl_cuda_put_sync_root(ucc_tl_cuda_task_t *task, ucc_rank_t root) { ucc_tl_cuda_team_t *team = TASK_TEAM(task); ucc_tl_cuda_sync_state_t *state = &team->sync_state[task->coll_id]; - if (UCC_TL_TEAM_RANK(team) == 0) { + if (UCC_TL_TEAM_RANK(team) == root) { ucc_assert(*state == task->seq_num); *state = 0; } } +static inline ucc_status_t ucc_tl_cuda_get_sync(ucc_tl_cuda_task_t *task) +{ + return ucc_tl_cuda_get_sync_root(task, 0); +} + +static inline void ucc_tl_cuda_put_sync(ucc_tl_cuda_task_t *task) +{ + ucc_tl_cuda_put_sync_root(task, 0); +} + ucc_status_t ucc_tl_cuda_mem_info_get(void *ptr, size_t length, ucc_tl_cuda_mem_info_t *mi); @@ -142,18 +188,26 @@ ucc_status_t ucc_tl_cuda_coll_init(ucc_base_coll_args_t *coll_args, ucc_status_t ucc_tl_cuda_coll_finalize(ucc_coll_task_t *coll_task); -ucc_status_t ucc_tl_cuda_shm_barrier_init(ucc_rank_t size, ucc_rank_t rank, - ucc_tl_cuda_shm_barrier_t *barrier); - -ucc_status_t ucc_tl_cuda_shm_barrier_start(ucc_rank_t rank, - ucc_tl_cuda_shm_barrier_t *barrier); - -ucc_status_t ucc_tl_cuda_shm_barrier_test(ucc_rank_t rank, - ucc_tl_cuda_shm_barrier_t *barrier); - ucc_status_t ucc_tl_cuda_alg_id_to_init(int alg_id, const char *alg_id_str, ucc_coll_type_t coll_type, ucc_memory_type_t mem_type, ucc_base_coll_init_fn_t *init); +// common utils function for collectives: +static inline int get_rank_step(ucc_tl_cuda_task_t *task, ucc_rank_t rank, + int step_id) +{ + ucc_tl_cuda_sync_t *sync = TASK_SYNC(task, rank); + + return sync->seq_num[step_id]; +} + +static inline void set_rank_step(ucc_tl_cuda_task_t *task, ucc_rank_t rank, + int step, int step_id) +{ + ucc_tl_cuda_sync_t *sync = TASK_SYNC(task, rank); + + sync->seq_num[step_id] = step; +} + #endif diff --git a/src/components/tl/cuda/tl_cuda_ring.h b/src/components/tl/cuda/tl_cuda_ring.h index cc2d3c95db..13835df99d 100644 --- a/src/components/tl/cuda/tl_cuda_ring.h +++ b/src/components/tl/cuda/tl_cuda_ring.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ @@ -83,20 +83,4 @@ static inline ucc_rank_t get_recv_block(ucc_tl_cuda_team_t *team, return ring->ring[(ring->iring[trank] + tsize - step - 1) % tsize]; } -static inline int get_rank_step(ucc_tl_cuda_task_t *task, ucc_rank_t rank, - int ring_id) -{ - ucc_tl_cuda_sync_t *sync = TASK_SYNC(task, rank); - - return sync->seq_num[ring_id]; -} - -static inline void set_rank_step(ucc_tl_cuda_task_t *task, ucc_rank_t rank, - int step, int ring_id) -{ - ucc_tl_cuda_sync_t *sync = TASK_SYNC(task, rank); - - sync->seq_num[ring_id] = step; -} - #endif diff --git a/src/components/tl/cuda/tl_cuda_team.c b/src/components/tl/cuda/tl_cuda_team.c index 64123a8cea..3b8a5fb253 100644 --- a/src/components/tl/cuda/tl_cuda_team.c +++ b/src/components/tl/cuda/tl_cuda_team.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -22,6 +22,8 @@ UCC_CLASS_INIT_FUNC(ucc_tl_cuda_team_t, ucc_base_context_t *tl_context, ucc_derived_of(tl_context, ucc_tl_cuda_context_t); ucc_tl_cuda_lib_t *lib = ucc_derived_of(tl_context->lib, ucc_tl_cuda_lib_t); + // Number of preallocated resource groups for tasks, including the active set. + uint32_t resource_num = lib->cfg.max_concurrent * 2; ucc_tl_cuda_shm_barrier_t *bar; ucc_status_t status; int shm_id, i, j; @@ -45,7 +47,8 @@ UCC_CLASS_INIT_FUNC(ucc_tl_cuda_team_t, ucc_base_context_t *tl_context, return UCC_ERR_NO_MEMORY; } - scratch_size = lib->cfg.max_concurrent * lib->cfg.scratch_size; + // active set + scratch_size = resource_num * lib->cfg.scratch_size; status = CUDA_FUNC(cudaMalloc(&self->scratch.loc, scratch_size)); if (status != UCC_OK) { tl_error(tl_context->lib, "failed to alloc scratch buffer"); @@ -64,6 +67,7 @@ UCC_CLASS_INIT_FUNC(ucc_tl_cuda_team_t, ucc_base_context_t *tl_context, lib->cfg.max_concurrent + sizeof(ucc_tl_cuda_shm_barrier_t) * lib->cfg.max_concurrent + sizeof(ucc_tl_cuda_sync_state_t) * lib->cfg.max_concurrent; + ctrl_size *= 2; // active sets shm_id = -1; self->sync = (void*)-1; @@ -77,10 +81,12 @@ UCC_CLASS_INIT_FUNC(ucc_tl_cuda_team_t, ucc_base_context_t *tl_context, goto ids_exchange; } memset(self->sync, 0, ctrl_size); - self->bar = (ucc_tl_cuda_shm_barrier_t*)UCC_TL_CUDA_TEAM_SYNC(self, 0, - lib->cfg.max_concurrent); - for (i = 0; i < lib->cfg.max_concurrent; i++) { + self->bar = (ucc_tl_cuda_shm_barrier_t *)UCC_TL_CUDA_TEAM_SYNC( + self, 0, resource_num); + /* active set */ + for (i = 0; i < resource_num; i++) { bar = UCC_TL_CUDA_TEAM_BARRIER(self, i); + bar->tag = UCC_TL_CUDA_TAG_FREE; // mark as free for (j = 0; j < UCC_TL_TEAM_SIZE(self); j++) { status = ucc_tl_cuda_shm_barrier_init(UCC_TL_TEAM_SIZE(self), j, bar); @@ -109,6 +115,7 @@ UCC_CLASS_INIT_FUNC(ucc_tl_cuda_team_t, ucc_base_context_t *tl_context, tl_debug(tl_context->lib, "posted tl team: %p", self); self->seq_num = 1; + self->seq_num_active_set = 1; return UCC_OK; free_devices: @@ -127,6 +134,8 @@ UCC_CLASS_CLEANUP_FUNC(ucc_tl_cuda_team_t) { ucc_tl_cuda_lib_t *lib = ucc_derived_of(self->super.super.context->lib, ucc_tl_cuda_lib_t); + // Number of preallocated resource groups for tasks, including the active set. 
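The doubled resource groups above are what active-set tasks claim using the key derived in ucc_tl_cuda_task_init earlier in this patch. As a standalone illustration of why the two ranks of an active-set pair agree on that key (simplified fixed-width types; this helper is not part of the change):

#include <stdint.h>

/* Both sides feed the same (tag, root, peer) triple into the key: the root
 * resolves subset rank 1 to the peer's physical rank, while the peer uses
 * its own rank, so the packed values match and both processes look for the
 * same barrier slot. */
static uint64_t active_set_pair_key(uint16_t tag, uint32_t root,
                                    uint32_t my_rank, uint32_t peer_phys_rank)
{
    uint32_t peer = (my_rank == root) ? peer_phys_rank : my_rank;
    return ((uint64_t)tag << 48) | ((uint64_t)root << 24) | peer;
}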
+ uint32_t resource_num = lib->cfg.max_concurrent * 2; ucc_tl_cuda_sync_t *sync; cudaError_t st; int i, j; @@ -137,7 +146,7 @@ UCC_CLASS_CLEANUP_FUNC(ucc_tl_cuda_team_t) } if (self->ids) { if (self->sync != (void*)-1) { - for (i = 0; i < lib->cfg.max_concurrent; i++) { + for (i = 0; i < resource_num; i++) { for (j = 0; j < UCC_TL_TEAM_SIZE(self); j++) { if (j == UCC_TL_TEAM_RANK(self)) { continue; @@ -199,6 +208,8 @@ ucc_status_t ucc_tl_cuda_team_create_test(ucc_base_team_t *tl_team) ucc_tl_cuda_team_t *team = ucc_derived_of(tl_team, ucc_tl_cuda_team_t); ucc_tl_cuda_lib_t *lib = ucc_derived_of(tl_team->context->lib, ucc_tl_cuda_lib_t); + // Number of preallocated resource groups for tasks, including the active set. + uint32_t resource_num = lib->cfg.max_concurrent * 2; ucc_status_t status; ucc_tl_cuda_sync_t *sync; ucc_tl_cuda_shm_barrier_t *bar; @@ -268,14 +279,14 @@ ucc_status_t ucc_tl_cuda_team_create_test(ucc_base_team_t *tl_team) goto exit_err; } team->bar = (ucc_tl_cuda_shm_barrier_t*)UCC_TL_CUDA_TEAM_SYNC(team, 0, - lib->cfg.max_concurrent); + resource_num); } team->sync_state = (ucc_tl_cuda_sync_state_t*)PTR_OFFSET(team->bar, sizeof(ucc_tl_cuda_shm_barrier_t) * - lib->cfg.max_concurrent); + resource_num); CUDA_CHECK_GOTO(cudaStreamCreateWithFlags(&team->stream, cudaStreamNonBlocking), exit_err, status); - for (i = 0; i < lib->cfg.max_concurrent; i++) { + for (i = 0; i < resource_num; i++) { sync = UCC_TL_CUDA_TEAM_SYNC(team, UCC_TL_TEAM_RANK(team), i); CUDA_CHECK_GOTO(cudaEventCreateWithFlags(&sync->ipc_event_local, cudaEventDisableTiming | @@ -303,7 +314,7 @@ ucc_status_t ucc_tl_cuda_team_create_test(ucc_base_team_t *tl_team) goto exit_err; } - for (i = 0; i < lib->cfg.max_concurrent; i++) { + for (i = 0; i < resource_num; i++) { sync = UCC_TL_CUDA_TEAM_SYNC(team, UCC_TL_TEAM_RANK(team), i); for (j = 0 ; j < UCC_TL_TEAM_SIZE(team); j++) { if (j == UCC_TL_TEAM_RANK(team)) { diff --git a/src/components/tl/cuda/tl_cuda_team_topo.h b/src/components/tl/cuda/tl_cuda_team_topo.h index 96b6d63a5b..a56b28bf21 100644 --- a/src/components/tl/cuda/tl_cuda_team_topo.h +++ b/src/components/tl/cuda/tl_cuda_team_topo.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -51,7 +51,7 @@ ucc_tl_cuda_team_topo_is_direct(const ucc_tl_team_t *team, } static inline int -ucc_tl_cuda_team_topo_is_fully_conntected(const ucc_tl_cuda_team_topo_t *topo) +ucc_tl_cuda_team_topo_is_fully_connected(const ucc_tl_cuda_team_topo_t *topo) { return topo->is_fully_connected; } diff --git a/src/components/tl/ucp/bcast/bcast_knomial.c b/src/components/tl/ucp/bcast/bcast_knomial.c index 1ca08893e3..62430024bf 100644 --- a/src/components/tl/ucp/bcast/bcast_knomial.c +++ b/src/components/tl/ucp/bcast/bcast_knomial.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ @@ -22,7 +22,7 @@ void ucc_tl_ucp_bcast_knomial_progress(ucc_coll_task_t *coll_task) ucc_rank_t size = (ucc_rank_t)task->subset.map.ep_num; uint32_t radix = task->bcast_kn.radix; - ucc_rank_t root = (uint32_t)TASK_ARGS(task).root; + ucc_rank_t root = (ucc_rank_t)TASK_ARGS(task).root; ucc_rank_t dist = task->bcast_kn.dist; void *buffer = TASK_ARGS(task).src.info.buffer; ucc_memory_type_t mtype = TASK_ARGS(task).src.info.mem_type; diff --git a/test/gtest/coll/test_bcast.cc b/test/gtest/coll/test_bcast.cc index 6d80816a31..00a9cd874e 100644 --- a/test/gtest/coll/test_bcast.cc +++ b/test/gtest/coll/test_bcast.cc @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * See file LICENSE for terms. */ @@ -276,6 +276,8 @@ ucc_job_env_t two_step_env = {{"UCC_CL_HIER_TUNE", "bcast:@2step:0-inf:inf"}, {"UCC_CLS", "all"}}; ucc_job_env_t dbt_env = {{"UCC_TL_UCP_TUNE", "bcast:@dbt:0-inf:inf"}, {"UCC_CLS", "basic"}}; +ucc_job_env_t cuda_env = {{"UCC_TL_CUDA_TUNE", "bcast:cuda:@0:0-inf:inf"}, + {"UCC_CLS", "basic"}}; INSTANTIATE_TEST_CASE_P( , test_bcast_alg, ::testing::Combine( @@ -285,6 +287,10 @@ INSTANTIATE_TEST_CASE_P( #else ::testing::Values(UCC_MEMORY_TYPE_HOST), #endif +#ifdef HAVE_CUDA + ::testing::Values(two_step_env, dbt_env, cuda_env), //env +#else ::testing::Values(two_step_env, dbt_env), //env +#endif ::testing::Values(8, 65536), // count - ::testing::Values(15,16))); // n_procs + ::testing::Values(15, 16))); // n_procs From 7405dfbdd08951f0b9761bb778fe09b42b86c5b3 Mon Sep 17 00:00:00 2001 From: Mamzi Bayatpour <77160721+MamziB@users.noreply.github.com> Date: Tue, 11 Feb 2025 11:28:05 -0800 Subject: [PATCH 09/11] CI: enhance ASAN/LSAN memory leak detections (#1074) --- .github/workflows/clang-tidy.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/clang-tidy.yaml b/.github/workflows/clang-tidy.yaml index f10720bee3..d46054ae56 100644 --- a/.github/workflows/clang-tidy.yaml +++ b/.github/workflows/clang-tidy.yaml @@ -43,6 +43,8 @@ jobs: rm -rf /tmp/ucc - name: Run gtest ASAN run: | + export ASAN_OPTIONS=fast_unwind_on_malloc=0:detect_leaks=1:print_suppressions=0 + export LSAN_OPTIONS=report_objects=1 cd ${GITHUB_WORKSPACE} ls -la echo $PWD From bc996dd4f9b06ede96744613b0270772c997368e Mon Sep 17 00:00:00 2001 From: Mamzi Bayatpour Date: Wed, 18 Dec 2024 11:52:13 -0800 Subject: [PATCH 10/11] TL/MLX5: generate schedule for zcopy allgather --- src/components/tl/mlx5/mcast/tl_mlx5_mcast.h | 6 + .../tl/mlx5/mcast/tl_mlx5_mcast_allgather.c | 173 ++++++++++++++++++ .../tl/mlx5/mcast/tl_mlx5_mcast_team.c | 4 + src/components/tl/mlx5/tl_mlx5.c | 9 + 4 files changed, 192 insertions(+) diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h index 49f2292166..db3e0eadcb 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h @@ -117,6 +117,8 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm_init_spec { int max_eager; int cuda_mem_enabled; int one_sided_reliability_enable; + int truly_zero_copy_allgather_enabled; + int mcast_prepost_bucket_size; void *oob; } ucc_tl_mlx5_mcast_coll_comm_init_spec_t; @@ -279,6 +281,8 @@ typedef struct ucc_tl_mlx5_mcast_allgather_comm { uint32_t coll_counter; uint32_t max_num_packets; uint32_t max_push_send; + uint8_t truly_zero_copy_allgather_enabled; + uint32_t mcast_prepost_bucket_size; } ucc_tl_mlx5_mcast_allgather_comm_t; 
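The two fields just added to the allgather comm are the tuning knobs of the pipelined zero-copy allgather implemented further below. A rough standalone sketch of how they interact (mirroring the total_steps computation in the schedule generation; not the patch's code): each step preposts mcast_prepost_bucket_size packets per group on half of the multicast groups while the other half receives.

/* Sketch only: step count of the zero-copy allgather pipeline. */
static int zcopy_allgather_total_steps(int num_packets, int comm_size,
                                       int concurrency_level,
                                       int prepost_bucket_size)
{
    /* every group of concurrency_level roots needs
     * num_packets / prepost_bucket_size prepost/multicast rounds,
     * plus one trailing step to drain the last preposted buffers */
    return num_packets * (comm_size / concurrency_level)
               / prepost_bucket_size + 1;
}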
 typedef struct ucc_tl_mlx5_mcast_bcast_comm {
@@ -431,6 +435,8 @@ typedef struct ucc_tl_mlx5_mcast_coll_req {
     ucc_memory_type_t                                   buf_mem_type;
     enum ucc_tl_mlx5_mcast_one_sided_reliability_scheme one_sided_reliability_scheme;
     uint32_t                                            ag_counter;
+    int                                                 concurrency_level;
+    int                                                 mcast_prepost_bucket_size;
     int                                                 state;
     ucc_tl_mlx5_mcast_pipelined_ag_schedule_t          *ag_schedule;
     int                                                 total_steps;
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.c
index 82592238d4..e5ef8d7359 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.c
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.c
@@ -270,6 +270,169 @@ void ucc_tl_mlx5_mcast_allgather_progress(ucc_coll_task_t *coll_task)
     }
 }
 
+static inline ucc_status_t
+ucc_tl_mlx5_mcast_validate_zero_copy_allgather_params(ucc_tl_mlx5_mcast_coll_comm_t *comm,
+                                                      ucc_tl_mlx5_mcast_coll_req_t *req)
+{
+
+    if (req->concurrency_level % 2 == 0 && req->num_packets % req->mcast_prepost_bucket_size != 0) {
+        tl_warn(comm->lib, "Pipelined mcast allgather not supported: "
+                "num_packets (%d) must be a multiple of mcast_prepost_bucket_size (%d) "
+                "when concurrency_level (%d) is even.",
+                req->num_packets, req->mcast_prepost_bucket_size, req->concurrency_level);
+        return UCC_ERR_NOT_SUPPORTED;
+    }
+
+    if (comm->commsize % req->concurrency_level != 0) {
+        tl_warn(comm->lib, "Pipelined mcast allgather not supported: "
+                "team size (%d) must be a multiple of concurrency_level (%d).",
+                comm->commsize, req->concurrency_level);
+        return UCC_ERR_NOT_SUPPORTED;
+    }
+
+    if (req->length % comm->max_per_packet != 0) {
+        tl_warn(comm->lib, "Pipelined mcast allgather not supported: "
+                "length (%ld) must be a multiple of max_per_packet (%d).",
+                req->length, comm->max_per_packet);
+        return UCC_ERR_NOT_SUPPORTED;
+    }
+
+    if (req->mcast_prepost_bucket_size * req->concurrency_level * 2 > comm->params.rx_depth) {
+        tl_warn(comm->lib, "Pipelined mcast allgather not supported: "
+                "it requires prepost_bucket_size * concurrency_level * 2 <= rx_depth, "
+                "but got: prepost_bucket_size=%d, concurrency_level=%d, "
+                "rx_depth=%d",
+                req->mcast_prepost_bucket_size, req->concurrency_level,
+                comm->params.rx_depth);
+        return UCC_ERR_NOT_SUPPORTED;
+    }
+
+    return UCC_OK;
+}
+
+
+/*
+ * at each stage half of the mcast groups are ready for receiving mcast
+ * packets while the other half are getting prepared by preposting recv
+ * buffers
+ */
+static inline ucc_status_t
+ucc_tl_mlx5_mcast_prepare_zero_copy_allgather(ucc_tl_mlx5_mcast_coll_comm_t *comm,
+                                              ucc_tl_mlx5_mcast_coll_req_t *req)
+{
+    ucc_tl_mlx5_mcast_reg_t                   *reg    = NULL;
+    ucc_rank_t                                 root   = 0;
+    int                                        offset = 0;
+    ucc_status_t                               status;
+    ucc_rank_t                                 j, i;
+    int                                        total_steps;
+    ucc_tl_mlx5_mcast_pipelined_ag_schedule_t *schedule;
+
+    ucc_assert(comm->allgather_comm.truly_zero_copy_allgather_enabled);
+
+    req->concurrency_level = comm->mcast_group_count / 2;
+    req->concurrency_level = ucc_min(req->concurrency_level, ONE_SIDED_MAX_CONCURRENT_LEVEL);
+    req->concurrency_level = ucc_min(req->concurrency_level, comm->commsize);
+
+    if (req->concurrency_level == 0) {
+        tl_warn(comm->lib, "not enough concurrency level to enable zcopy pipeline allgather");
+        return UCC_ERR_NOT_SUPPORTED;
+    }
+
+    req->mcast_prepost_bucket_size =
+        ucc_min(req->num_packets, comm->allgather_comm.mcast_prepost_bucket_size);
+
+    status = ucc_tl_mlx5_mcast_validate_zero_copy_allgather_params(comm, req);
+    if (status != UCC_OK) {
+        return status;
+    }
+
+    /* calculate the schedule
and details of what we should + * mcast and prepost to which mcast group at each stage*/ + total_steps = req->num_packets * (comm->commsize / req->concurrency_level) + / req->mcast_prepost_bucket_size + 1; + + schedule = ucc_calloc(1, + sizeof(ucc_tl_mlx5_mcast_pipelined_ag_schedule_t) * + total_steps, "sched"); + if (!schedule) { + tl_warn(comm->lib, "cannot allocate memory for schedule list"); + return UCC_ERR_NO_MEMORY; + } + + /* generate schedule */ + for (i = 0; i < total_steps; i++) { + if (i < total_steps - 1) { + for (j = 0; j < req->concurrency_level; j++) { + schedule[i].prepost_buf_op[j].group_id = + j + req->concurrency_level * (i % 2); + schedule[i].prepost_buf_op[j].offset = + offset * comm->max_per_packet; + schedule[i].prepost_buf_op[j].root = root + j; + schedule[i].prepost_buf_op[j].count = + req->mcast_prepost_bucket_size; + } + } else { + schedule[i].prepost_buf_op_done = 1; + } + + if (i > 0) { + for (j = 0; j < req->concurrency_level; j++) { + schedule[i].multicast_op[j].group_id = + schedule[i - 1].prepost_buf_op[j].group_id; + schedule[i].multicast_op[j].offset = + schedule[i - 1].prepost_buf_op[j].offset; + schedule[i].multicast_op[j].offset_left = + schedule[i - 1].prepost_buf_op[j].offset; + schedule[i].multicast_op[j].root = + schedule[i - 1].prepost_buf_op[j].root; + schedule[i].multicast_op[j].to_send_left = + schedule[i - 1].prepost_buf_op[j].count; + schedule[i].multicast_op[j].to_recv = + schedule[i - 1].prepost_buf_op[j].count; + schedule[i].to_recv += schedule[i].multicast_op[j].to_recv; + if (schedule[i].multicast_op[j].root == comm->rank) { + schedule[i].to_send += schedule[i].multicast_op[j].to_send_left; + } + } + } + + if (!schedule[i].to_send || !schedule[i].to_recv) { + schedule[i].multicast_op_done = 1; + } + + offset += req->mcast_prepost_bucket_size; + + if (offset == req->num_packets) { + offset = 0; + root = (root + req->concurrency_level) % comm->commsize; + } + } + + tl_trace(comm->lib, + "generated the schedule for pipelined zero copy allgather with total_steps %d", + total_steps); + schedule->total_steps = total_steps; + req->total_steps = total_steps; + req->ag_schedule = schedule; + tl_trace(comm->lib, "registering recv buf of size %ld", req->length * comm->commsize); + ucc_assert(req->recv_rreg == NULL); + + status = ucc_tl_mlx5_mcast_mem_register(comm->ctx, req->rptr, req->length * + comm->commsize, ®); + if (UCC_OK != status) { + tl_warn(comm->lib, "unable to register receive buffer %p of size %ld", + req->rptr, req->length * comm->commsize); + ucc_free(schedule); + return status; + } + + req->recv_rreg = reg; + req->recv_mr = reg->mr; + + return UCC_OK; +} + ucc_status_t ucc_tl_mlx5_mcast_allgather_init(ucc_tl_mlx5_task_t *task) { ucc_coll_task_t *coll_task = &(task->super); @@ -357,6 +520,13 @@ ucc_status_t ucc_tl_mlx5_mcast_allgather_init(ucc_tl_mlx5_task_t *task) req->to_send = req->num_packets; req->to_recv = comm->commsize * req->num_packets; + if (comm->allgather_comm.truly_zero_copy_allgather_enabled) { + status = ucc_tl_mlx5_mcast_prepare_zero_copy_allgather(comm, req); + if (UCC_OK != status) { + goto failed; + } + } + comm->allgather_comm.coll_counter++; task->coll_mcast.req_handle = req; @@ -368,6 +538,9 @@ ucc_status_t ucc_tl_mlx5_mcast_allgather_init(ucc_tl_mlx5_task_t *task) failed: tl_warn(UCC_TASK_LIB(task), "mcast init allgather failed:%d", status); if (req) { + if (req->rreg) { + ucc_tl_mlx5_mcast_mem_deregister(comm->ctx, req->rreg); + } ucc_mpool_put(req); } return status; diff --git 
a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c index 84efb5daf1..e7b6014cb4 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c @@ -99,6 +99,10 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context, memcpy(&comm->params, conf_params, sizeof(*conf_params)); + comm->allgather_comm.mcast_prepost_bucket_size + = conf_params->mcast_prepost_bucket_size; + comm->allgather_comm.truly_zero_copy_allgather_enabled + = conf_params->truly_zero_copy_allgather_enabled; comm->one_sided.reliability_enabled = conf_params->one_sided_reliability_enable; comm->bcast_comm.wsize = conf_params->wsize; comm->allgather_comm.max_push_send = conf_params->max_push_send; diff --git a/src/components/tl/mlx5/tl_mlx5.c b/src/components/tl/mlx5/tl_mlx5.c index 5cdd6c51a1..b9d48edb7b 100644 --- a/src/components/tl/mlx5/tl_mlx5.c +++ b/src/components/tl/mlx5/tl_mlx5.c @@ -104,6 +104,15 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = { ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.one_sided_reliability_enable), UCC_CONFIG_TYPE_BOOL}, + {"MCAST_ZERO_COPY_ALLGATHER_ENABLE", "1", "Enable truly zero copy allgather design for mcast", + ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.truly_zero_copy_allgather_enabled), + UCC_CONFIG_TYPE_BOOL}, + + {"MCAST_ZERO_COPY_PREPOST_BUCKET_SIZE", "16", "Number of posted recvs during each stage of the pipeline" + " in truly zero copy mcast allgather design", + ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.mcast_prepost_bucket_size), + UCC_CONFIG_TYPE_INT}, + {NULL}}; static ucc_config_field_t ucc_tl_mlx5_context_config_table[] = { From 9be1ac77ef334c28482ec5c9476289304f34128e Mon Sep 17 00:00:00 2001 From: Devendar Bureddy Date: Wed, 26 Feb 2025 01:53:50 -0800 Subject: [PATCH 11/11] TL/SHARP: Option to enable SHARP multi-channel (#1049) --- src/components/tl/sharp/tl_sharp.c | 7 ++++++- src/components/tl/sharp/tl_sharp.h | 3 ++- src/components/tl/sharp/tl_sharp_context.c | 6 +++++- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/components/tl/sharp/tl_sharp.c b/src/components/tl/sharp/tl_sharp.c index 464ef50478..fe86950bf7 100644 --- a/src/components/tl/sharp/tl_sharp.c +++ b/src/components/tl/sharp/tl_sharp.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -72,6 +72,11 @@ static ucc_config_field_t ucc_tl_sharp_context_config_table[] = { ucc_offsetof(ucc_tl_sharp_context_config_t, team_max_ppn), UCC_CONFIG_TYPE_UINT}, + {"USE_MULTI_CHANNEL", "0", + "Use SHARP Multi-channel feature. Options: 0-disable 1-enable", + ucc_offsetof(ucc_tl_sharp_context_config_t, use_multi_channel), + UCC_CONFIG_TYPE_BOOL}, + {NULL}}; UCC_CLASS_DEFINE_NEW_FUNC(ucc_tl_sharp_lib_t, ucc_base_lib_t, diff --git a/src/components/tl/sharp/tl_sharp.h b/src/components/tl/sharp/tl_sharp.h index adfbc86036..875b9c6689 100644 --- a/src/components/tl/sharp/tl_sharp.h +++ b/src/components/tl/sharp/tl_sharp.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ @@ -53,6 +53,7 @@ typedef struct ucc_tl_sharp_context_config { int context_per_team; int enable_lazy_group_alloc; int team_max_ppn; + int use_multi_channel; } ucc_tl_sharp_context_config_t; typedef struct ucc_tl_sharp_lib { diff --git a/src/components/tl/sharp/tl_sharp_context.c b/src/components/tl/sharp/tl_sharp_context.c index 42d10f8d87..ed7d50578b 100644 --- a/src/components/tl/sharp/tl_sharp_context.c +++ b/src/components/tl/sharp/tl_sharp_context.c @@ -305,7 +305,11 @@ ucc_status_t ucc_tl_sharp_context_init(ucc_tl_sharp_context_t *sharp_ctx, init_spec.progress_func = NULL; init_spec.world_local_rank = local_rank; - init_spec.group_channel_idx = 0; + if (sharp_ctx->cfg.use_multi_channel) { + init_spec.group_channel_idx = local_rank; + } else { + init_spec.group_channel_idx = 0; + } init_spec.oob_ctx = oob_ctx; init_spec.config = sharp_coll_default_config; init_spec.config.user_progress_num_polls = sharp_ctx->cfg.uprogress_num_polls;
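For reference, the channel selection that the new USE_MULTI_CHANNEL option switches on can be summarized by a small standalone helper (illustrative only, not the SHARP or UCC API): with the option disabled every team keeps using channel 0, while with it enabled each process derives its channel from its node-local rank so teams are spread across SHARP channels.

/* Sketch of the selection rule applied above in ucc_tl_sharp_context_init. */
static int sharp_group_channel_idx(int use_multi_channel, int world_local_rank)
{
    return use_multi_channel ? world_local_rank : 0;
}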