diff --git a/include/rdma/fabric.h b/include/rdma/fabric.h index 42c50532797..2059fcbbdd2 100644 --- a/include/rdma/fabric.h +++ b/include/rdma/fabric.h @@ -346,6 +346,7 @@ enum { FI_TAG_BITS, FI_TAG_MPI, FI_TAG_CCL, + FI_TAG_RPC, FI_TAG_MAX_FORMAT = (1ULL << 16), }; diff --git a/man/fi_endpoint.3.md b/man/fi_endpoint.3.md index 973ec8e73e4..6a210cf1b89 100644 --- a/man/fi_endpoint.3.md +++ b/man/fi_endpoint.3.md @@ -919,6 +919,17 @@ wire protocols. The following tag formats are defined: Applications that use the CCL format pass in the payload identifier directly as the tag and set ignore bits to 0. +*FI_TAG_RPC* + +: The FI_TAG_RPC flag is used to indicate that tags are being utilized to match + RPC requests and replies. When specified via fi_getinfo, the caller ensures that + a reply buffer with the corresponding tag is registered when sending a request. + + This mechanism enables libfabric to identify and discard stale replies, preventing + them from interfering with new communications. This is crucial to avoid blocking + a restarting endpoint, which may otherwise lack sufficient metadata to process + incoming messages with unmatched tags. + *FI_TAG_MAX_FORMAT* : If the value of mem_tag_format is >= FI_TAG_MAX_FORMAT, the tag format is treated as a set of bit fields. The behavior is functionally the same diff --git a/prov/tcp/src/xnet.h b/prov/tcp/src/xnet.h index 486af642d47..d6dffc8c930 100644 --- a/prov/tcp/src/xnet.h +++ b/prov/tcp/src/xnet.h @@ -264,6 +264,7 @@ struct xnet_ep { void (*hdr_bswap)(struct xnet_ep *ep, struct xnet_base_hdr *hdr); short pollflags; + bool tagged_rpc; xnet_profile_t *profile; }; @@ -428,6 +429,7 @@ static inline void xnet_signal_progress(struct xnet_progress *progress) #define XNET_COPY_RECV BIT(9) #define XNET_CLAIM_RECV BIT(10) #define XNET_NEED_CTS BIT(11) +#define XNET_UNEXP_XFER BIT(12) #define XNET_MULTI_RECV FI_MULTI_RECV /* BIT(16) */ struct xnet_mrecv { diff --git a/prov/tcp/src/xnet_cq.c b/prov/tcp/src/xnet_cq.c index 2090bdf7170..294a3fe73d7 100644 --- a/prov/tcp/src/xnet_cq.c +++ b/prov/tcp/src/xnet_cq.c @@ -130,7 +130,7 @@ void xnet_report_success(struct xnet_xfer_entry *xfer_entry) uint64_t flags, data, tag; size_t len; - if (xfer_entry->ctrl_flags & (XNET_INTERNAL_XFER | XNET_SAVED_XFER)) + if (xfer_entry->ctrl_flags & (XNET_INTERNAL_XFER | XNET_SAVED_XFER | XNET_UNEXP_XFER)) return; if (xfer_entry->cntr) diff --git a/prov/tcp/src/xnet_ep.c b/prov/tcp/src/xnet_ep.c index 0ff5723d9d2..b6a7a60fd62 100644 --- a/prov/tcp/src/xnet_ep.c +++ b/prov/tcp/src/xnet_ep.c @@ -283,6 +283,7 @@ xnet_ep_accept(struct fid_ep *ep_fid, const void *param, size_t paramlen) (paramlen > XNET_MAX_CM_DATA_SIZE)) return -FI_EINVAL; + ep->tagged_rpc = conn->pep->info->ep_attr->mem_tag_format == FI_TAG_RPC; ep->conn = NULL; assert(ep->cm_msg); diff --git a/prov/tcp/src/xnet_init.c b/prov/tcp/src/xnet_init.c index 0805ad94088..fd1f096223f 100644 --- a/prov/tcp/src/xnet_init.c +++ b/prov/tcp/src/xnet_init.c @@ -48,8 +48,16 @@ static int xnet_getinfo(uint32_t version, const char *node, const char *service, uint64_t flags, const struct fi_info *hints, struct fi_info **info) { - return ofi_ip_getinfo(&xnet_util_prov, version, node, service, flags, - hints, info); + int ret; + + ret = ofi_ip_getinfo(&xnet_util_prov, version, node, service, flags, hints, info); + if (ret) + return ret; + + if (hints->ep_attr && hints->ep_attr->mem_tag_format && (*info)->ep_attr) + (*info)->ep_attr->mem_tag_format = hints->ep_attr->mem_tag_format; + + return 0; } struct xnet_port_range xnet_ports = { diff --git a/prov/tcp/src/xnet_progress.c b/prov/tcp/src/xnet_progress.c index aa76968e175..f1ac38ddad6 100644 --- a/prov/tcp/src/xnet_progress.c +++ b/prov/tcp/src/xnet_progress.c @@ -103,6 +103,44 @@ static bool xnet_save_and_cont(struct xnet_ep *ep) return (ep->saved_msg->cnt < xnet_max_saved); } +static struct xnet_xfer_entry * +xnet_get_unexp_rx(struct xnet_ep *ep, uint64_t tag) +{ + struct xnet_progress *progress; + struct xnet_xfer_entry *rx_entry; + + progress = xnet_ep2_progress(ep); + assert(xnet_progress_locked(progress)); + assert(ep->cur_rx.hdr_done == ep->cur_rx.hdr_len && + !ep->cur_rx.claim_ctx); + + FI_DBG(&xnet_prov, FI_LOG_EP_DATA, "Unexp msg tag 0x%zx src %zu\n", + tag, ep->peer->fi_addr); + rx_entry = xnet_alloc_xfer(xnet_srx2_progress(ep->srx)); + if (!rx_entry) + return NULL; + + rx_entry->saving_ep = NULL; + rx_entry->tag = tag; + rx_entry->ignore = 0; + rx_entry->ctrl_flags = XNET_UNEXP_XFER; + + if (ep->cur_rx.data_left <= xnet_buf_size) { + rx_entry->user_buf = NULL; + rx_entry->iov[0].iov_base = &rx_entry->msg_data; + rx_entry->iov[0].iov_len = xnet_buf_size; + rx_entry->iov_cnt = 1; + } else if (xnet_alloc_xfer_buf(rx_entry, ep->cur_rx.data_left)) { + goto free_xfer; + } + + return rx_entry; + +free_xfer: + xnet_free_xfer(progress, rx_entry); + return NULL; +} + static struct xnet_xfer_entry * xnet_get_save_rx(struct xnet_ep *ep, uint64_t tag) { @@ -822,6 +860,14 @@ static int xnet_handle_tag(struct xnet_ep *ep) if (rx_entry) return xnet_start_recv(ep, rx_entry); } + + if (ep->tagged_rpc) { + /* receive and discard this unexpected message for tagged rpc */ + rx_entry = xnet_get_unexp_rx(ep, tag); + if (rx_entry) + return xnet_start_recv(ep, rx_entry); + } + if (dlist_empty(&ep->unexp_entry)) { dlist_insert_tail(&ep->unexp_entry, &xnet_ep2_progress(ep)->unexp_tag_list); @@ -1102,7 +1148,7 @@ static void xnet_complete_rx(struct xnet_ep *ep, ssize_t ret) goto cq_error; } - if (!(rx_entry->ctrl_flags & XNET_SAVED_XFER)) { + if (!(rx_entry->ctrl_flags & XNET_SAVED_XFER) || (rx_entry->ctrl_flags & XNET_UNEXP_XFER)) { xnet_report_success(rx_entry); xnet_free_xfer(xnet_ep2_progress(ep), rx_entry); } else { diff --git a/prov/tcp/src/xnet_rdm.c b/prov/tcp/src/xnet_rdm.c index 77456860569..a18426dc012 100644 --- a/prov/tcp/src/xnet_rdm.c +++ b/prov/tcp/src/xnet_rdm.c @@ -1044,6 +1044,8 @@ static int xnet_init_rdm(struct xnet_rdm *rdm, struct fi_info *info) msg_info->tx_attr->op_flags = info->tx_attr->op_flags; msg_info->rx_attr->caps &= info->rx_attr->caps; msg_info->rx_attr->op_flags = info->rx_attr->op_flags; + if (info->ep_attr) + msg_info->ep_attr->mem_tag_format = info->ep_attr->mem_tag_format; ret = fi_srx_context(&rdm->util_ep.domain->domain_fid, info->rx_attr, &srx, rdm);