From 21898d1618c53cac4fef077d30f0582bd0b35b64 Mon Sep 17 00:00:00 2001 From: Lindsay Reiser Date: Mon, 24 Feb 2025 16:53:13 -0500 Subject: [PATCH] prov/opx: Move CUDA sync attribute setting to mr registration Signed-off-by: Lindsay Reiser --- prov/opx/src/fi_opx_hfi1.c | 20 -------------------- prov/opx/src/fi_opx_mr.c | 15 +++++++++++++++ 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/prov/opx/src/fi_opx_hfi1.c b/prov/opx/src/fi_opx_hfi1.c index 64f56535ac8..fe633fdf8e5 100644 --- a/prov/opx/src/fi_opx_hfi1.c +++ b/prov/opx/src/fi_opx_hfi1.c @@ -1206,16 +1206,6 @@ int opx_hfi1_rx_rzv_rts_send_cts(union fi_opx_hfi1_deferred_work *work) } } -#ifdef HAVE_CUDA - if (params->dput_iov[0].rbuf_iface == FI_HMEM_CUDA) { - int err = cuda_set_sync_memops((void *) params->dput_iov[0].rbuf); - if (OFI_UNLIKELY(err != 0)) { - FI_WARN(fi_opx_global.prov, FI_LOG_MR, "cuda_set_sync_memops(%p) FAILED (returned %d)\n", - (void *) params->dput_iov[0].rbuf, err); - } - } -#endif - fi_opx_reliability_service_do_replay(&opx_ep->reliability->service, replay); fi_opx_reliability_client_replay_register_no_update(&opx_ep->reliability->state, params->origin_rx, psn_ptr, replay, params->reliability, OPX_HFI1_TYPE); @@ -2504,16 +2494,6 @@ int opx_hfi1_rx_rma_rts_send_cts(union fi_opx_hfi1_deferred_work *work) tx_payload->cts.iov[i] = params->dput_iov[i]; } -#ifdef HAVE_CUDA - if (params->dput_iov[0].rbuf_iface == FI_HMEM_CUDA) { - int err = cuda_set_sync_memops((void *) params->dput_iov[0].rbuf); - if (OFI_UNLIKELY(err != 0)) { - FI_WARN(fi_opx_global.prov, FI_LOG_MR, "cuda_set_sync_memops(%p) FAILED (returned %d)\n", - (void *) params->dput_iov[0].rbuf, err); - } - } -#endif - fi_opx_reliability_service_do_replay(&opx_ep->reliability->service, replay); fi_opx_reliability_client_replay_register_no_update(&opx_ep->reliability->state, params->origin_rx, psn_ptr, replay, params->reliability, hfi1_type); diff --git a/prov/opx/src/fi_opx_mr.c b/prov/opx/src/fi_opx_mr.c index 6a964ccdc75..e0dba13290f 100644 --- a/prov/opx/src/fi_opx_mr.c +++ b/prov/opx/src/fi_opx_mr.c @@ -174,6 +174,16 @@ static inline int fi_opx_mr_reg_internal(struct fid *fid, const struct iovec *io } opx_mr->attr.requested_key = opx_mr->mr_fid.key; +#ifdef HAVE_CUDA + if (hmem_iface == FI_HMEM_CUDA) { + int err = cuda_set_sync_memops((void *) iov->iov_base); + if (OFI_UNLIKELY(err != 0)) { + FI_WARN(fi_opx_global.prov, FI_LOG_MR, + "cuda_set_sync_memops(%p) FAILED (returned %d)\n", + (void *) iov->iov_base, err); + } + } +#endif if (opx_mr->domain->mr_mode & OFI_MR_SCALABLE) { fi_opx_ref_inc(&opx_mr->domain->ref_cnt, "domain"); } @@ -198,6 +208,11 @@ static inline int fi_opx_mr_reg_internal(struct fid *fid, const struct iovec *io switch (hmem_iface) { case FI_HMEM_CUDA: opx_mr->attr.device.cuda = (int) hmem_device; + int err = cuda_set_sync_memops((void *) iov->iov_base); + if (OFI_UNLIKELY(err != 0)) { + FI_WARN(fi_opx_global.prov, FI_LOG_MR, "cuda_set_sync_memops(%p) FAILED (returned %d)\n", + (void *) iov->iov_base, err); + } break; case FI_HMEM_ZE: opx_mr->attr.device.ze = (int) hmem_device;