Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rdma: Fix request lookup on error #800

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 26 additions & 14 deletions src/nccl_ofi_rdma.c
Original file line number Diff line number Diff line change
Expand Up @@ -1581,6 +1581,26 @@ static int post_flush_req(nccl_net_ofi_rdma_req_t *req);

static int post_eager_copy(nccl_net_ofi_rdma_req_t *req);


static nccl_net_ofi_rdma_req_t *rdma_op_context_get_req(void *op_context, int rail_id)
{
struct fi_context2 *ctx = (struct fi_context2 *)op_context;
if (OFI_UNLIKELY(ctx == NULL)) {
return NULL;
}

/* To find the request, we need to find the
* start of the context array. Since the
* sender will always use its rail_id for the
* ctx array index, we can do the same.
*/
ctx -= rail_id;
return container_of(ctx,
nccl_net_ofi_rdma_req_t,
ctx);
}


/*
* @brief Processes completion entries from CQ
*
Expand Down Expand Up @@ -1615,20 +1635,11 @@ static inline int process_completions(struct fi_cq_data_entry *cq_entry, uint64_
/* Remote-initiated write is complete */
ret = handle_write_comp(&cq_entry[comp_idx], device, rail_id);
} else {
struct fi_context2 *ctx = (struct fi_context2 *)cq_entry[comp_idx].op_context;
if (OFI_UNLIKELY(ctx == NULL)) {
req = rdma_op_context_get_req(cq_entry[comp_idx].op_context, rail_id);
if (OFI_UNLIKELY(req == NULL)) {
NCCL_OFI_WARN("Completion with unexpected NULL op_context");
return -EINVAL;
}
/* To find the request, we need to find the
* start of the context array. Since the
* sender will always use its rail_id for the
* ctx array index, we can do the same.
*/
ctx -= rail_id;
req = container_of(ctx,
nccl_net_ofi_rdma_req_t,
ctx);

if (comp_flags & FI_SEND) {
/* Send completions */
Expand Down Expand Up @@ -1751,11 +1762,12 @@ static inline int process_completions(struct fi_cq_data_entry *cq_entry, uint64_
* error, on others
*/
static inline int process_err_completion(nccl_net_ofi_rdma_device_t *device,
struct fid_cq *cq)
nccl_net_ofi_ep_rail_t *rail)
{
struct fi_cq_err_entry err_entry = {};
nccl_net_ofi_rdma_req_t *req = NULL;
int ret = 0;
struct fid_cq *cq = rail->cq;

ret = fi_cq_readerr(cq, &err_entry, 0);
if (OFI_UNLIKELY(ret == -FI_EAGAIN)) {
Expand Down Expand Up @@ -1796,7 +1808,7 @@ static inline int process_err_completion(nccl_net_ofi_rdma_device_t *device,
ret = -EIO;
goto exit;
}
req = (nccl_net_ofi_rdma_req_t *)err_entry.op_context;
req = rdma_op_context_get_req(err_entry.op_context, rail->rail_id);
}

NCCL_OFI_WARN("Request %p completed with error. RC: %d. Error: %d (%s). Completed length: %ld, Request: %s",
Expand Down Expand Up @@ -1996,7 +2008,7 @@ static int ofi_process_cq_rail(nccl_net_ofi_rdma_ep_t *ep, nccl_net_ofi_ep_rail_
if (OFI_UNLIKELY(ret != 0))
goto exit;
} else if (OFI_UNLIKELY(rc == -FI_EAVAIL)) {
ret = process_err_completion(rdma_endpoint_get_device(ep), rail->cq);
ret = process_err_completion(rdma_endpoint_get_device(ep), rail);
if (ret == 0) {
/* Error entry not available yet */
break;
Expand Down
Loading