Skip to content

Commit

Permalink
TL/MLX5: Stabilize staging-based AG, Add zcopy AG
Browse files Browse the repository at this point in the history
  • Loading branch information
MamziB committed Mar 1, 2025
1 parent d6e7ec2 commit b4cd6e1
Show file tree
Hide file tree
Showing 15 changed files with 590 additions and 214 deletions.
2 changes: 2 additions & 0 deletions src/components/tl/mlx5/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ mcast = \
mcast/tl_mlx5_mcast_service_coll.c \
mcast/tl_mlx5_mcast_one_sided_reliability.h \
mcast/tl_mlx5_mcast_one_sided_reliability.c \
mcast/tl_mlx5_mcast_one_sided_progress.h \
mcast/tl_mlx5_mcast_one_sided_progress.c \
mcast/tl_mlx5_mcast_allgather.h \
mcast/tl_mlx5_mcast_allgather.c \
mcast/tl_mlx5_mcast_team.c
Expand Down
11 changes: 7 additions & 4 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ enum {
MCAST_CALC_WR,
MCAST_BCASTRECV_WR,
MCAST_BCASTSEND_WR,
MCAST_AG_RDMA_READ_INFO_WR,
MCAST_AG_RDMA_READ_WR,
};

struct ucc_tl_mlx5_mcast_p2p_completion_obj;
Expand Down Expand Up @@ -117,6 +119,7 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm_init_spec {
int max_eager;
int cuda_mem_enabled;
int one_sided_reliability_enable;
int reliability_scheme_msg_threshold;
int truly_zero_copy_allgather_enabled;
int mcast_prepost_bucket_size;
void *oob;
Expand Down Expand Up @@ -218,8 +221,6 @@ struct mcast_ctx {
struct mcast_group groups[MAX_GROUP_COUNT];
// RC connection info for supporing one-sided based relibality
struct ibv_qp **rc_qp;
uint16_t *rc_lid;
union ibv_gid *rc_gid;
};

struct packet {
Expand Down Expand Up @@ -265,7 +266,7 @@ typedef struct ucc_tl_mlx5_mcast_one_sided_reliability_comm {
ucc_service_coll_req_t *reliability_req;
int reliability_enabled;
int reliability_ready;
int rdma_read_in_progress;
int pending_reads;
enum ucc_tl_mlx5_mcast_one_sided_slot_states slots_state;
} ucc_tl_mlx5_mcast_one_sided_reliability_comm_t;

Expand Down Expand Up @@ -446,6 +447,7 @@ typedef struct ucc_tl_mlx5_mcast_coll_req {
void *recv_rreg;
ucc_ee_executor_task_t *exec_task;
ucc_coll_task_t *coll_task;
ucc_status_t (*progress) (void *req);
} ucc_tl_mlx5_mcast_coll_req_t;

typedef struct ucc_tl_mlx5_mcast_oob_p2p_context {
Expand Down Expand Up @@ -477,7 +479,8 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_post_recv_buffers(ucc_tl_mlx5_mcast
int count = comm->params.rx_depth - comm->pending_recv;
int i;

if (count <= comm->params.post_recv_thresh) {
if (comm->allgather_comm.truly_zero_copy_allgather_enabled ||
count <= comm->params.post_recv_thresh) {
return UCC_OK;
}

Expand Down
Loading

0 comments on commit b4cd6e1

Please sign in to comment.