diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c index 36f11cff9..25589e4a8 100644 --- a/src/nccl_ofi_rdma.c +++ b/src/nccl_ofi_rdma.c @@ -1581,6 +1581,26 @@ static int post_flush_req(nccl_net_ofi_rdma_req_t *req); static int post_eager_copy(nccl_net_ofi_rdma_req_t *req); + +static nccl_net_ofi_rdma_req_t *rdma_op_context_get_req(void *op_context, int rail_id) +{ + struct fi_context2 *ctx = (struct fi_context2 *)op_context; + if (OFI_UNLIKELY(ctx == NULL)) { + return NULL; + } + + /* To find the request, we need to find the + * start of the context array. Since the + * sender will always use its rail_id for the + * ctx array index, we can do the same. + */ + ctx -= rail_id; + return container_of(ctx, + nccl_net_ofi_rdma_req_t, + ctx); +} + + /* * @brief Processes completion entries from CQ * @@ -1615,20 +1635,11 @@ static inline int process_completions(struct fi_cq_data_entry *cq_entry, uint64_ /* Remote-initiated write is complete */ ret = handle_write_comp(&cq_entry[comp_idx], device, rail_id); } else { - struct fi_context2 *ctx = (struct fi_context2 *)cq_entry[comp_idx].op_context; - if (OFI_UNLIKELY(ctx == NULL)) { + req = rdma_op_context_get_req(cq_entry[comp_idx].op_context, rail_id); + if (OFI_UNLIKELY(req == NULL)) { NCCL_OFI_WARN("Completion with unexpected NULL op_context"); return -EINVAL; } - /* To find the request, we need to find the - * start of the context array. Since the - * sender will always use its rail_id for the - * ctx array index, we can do the same. - */ - ctx -= rail_id; - req = container_of(ctx, - nccl_net_ofi_rdma_req_t, - ctx); if (comp_flags & FI_SEND) { /* Send completions */ @@ -1751,11 +1762,12 @@ static inline int process_completions(struct fi_cq_data_entry *cq_entry, uint64_ * error, on others */ static inline int process_err_completion(nccl_net_ofi_rdma_device_t *device, - struct fid_cq *cq) + nccl_net_ofi_ep_rail_t *rail) { struct fi_cq_err_entry err_entry = {}; nccl_net_ofi_rdma_req_t *req = NULL; int ret = 0; + struct fid_cq *cq = rail->cq; ret = fi_cq_readerr(cq, &err_entry, 0); if (OFI_UNLIKELY(ret == -FI_EAGAIN)) { @@ -1796,7 +1808,7 @@ static inline int process_err_completion(nccl_net_ofi_rdma_device_t *device, ret = -EIO; goto exit; } - req = (nccl_net_ofi_rdma_req_t *)err_entry.op_context; + req = rdma_op_context_get_req(err_entry.op_context, rail->rail_id); } NCCL_OFI_WARN("Request %p completed with error. RC: %d. Error: %d (%s). Completed length: %ld, Request: %s", @@ -1996,7 +2008,7 @@ static int ofi_process_cq_rail(nccl_net_ofi_rdma_ep_t *ep, nccl_net_ofi_ep_rail_ if (OFI_UNLIKELY(ret != 0)) goto exit; } else if (OFI_UNLIKELY(rc == -FI_EAVAIL)) { - ret = process_err_completion(rdma_endpoint_get_device(ep), rail->cq); + ret = process_err_completion(rdma_endpoint_get_device(ep), rail); if (ret == 0) { /* Error entry not available yet */ break;