From c414da3bc26c3d9dc510478bf496a68d75a15fff Mon Sep 17 00:00:00 2001
From: Joseph Antony
Date: Tue, 31 Mar 2026 16:46:27 -0400
Subject: [PATCH] UCX Notified Communication

Signed-off-by: Joseph Antony
---
 ompi/mca/osc/ucx/osc_ucx.h           |  40 +++++
 ompi/mca/osc/ucx/osc_ucx_comm.c      | 213 +++++++++++++++++++++++++++
 ompi/mca/osc/ucx/osc_ucx_component.c |  44 +++++-
 3 files changed, 293 insertions(+), 4 deletions(-)

diff --git a/ompi/mca/osc/ucx/osc_ucx.h b/ompi/mca/osc/ucx/osc_ucx.h
index bc3dc8a91b3..0266f930ec3 100644
--- a/ompi/mca/osc/ucx/osc_ucx.h
+++ b/ompi/mca/osc/ucx/osc_ucx.h
@@ -27,6 +27,19 @@
 #define OMPI_OSC_UCX_POST_PEER_MAX 32
 #define OMPI_OSC_UCX_ATTACH_MAX 48
 #define OMPI_OSC_UCX_MEM_ADDR_MAX_LEN 1024
+#define OMPI_OSC_UCX_MAX_NOTIFY_COUNTERS 16
+
+/* Byte offset of a window's notify counters from the window base: the window
+ * size rounded up to a multiple of sizeof(uint64_t), so that the 64-bit
+ * counters stay 8-byte aligned relative to the base. */
+#define OMPI_OSC_UCX_NOTIFY_OFFSET(size) \
+    (((uint64_t)(size) + sizeof(uint64_t) - 1) & ~((uint64_t)sizeof(uint64_t) - 1))
+
+/* Notify counters are only provisioned for flavors whose memory is sized by
+ * this component (ALLOCATE, SHARED). A CREATE window wraps a user buffer with
+ * no room for them, and DYNAMIC windows never register them. */
+#define OMPI_OSC_UCX_FLAVOR_HAS_NOTIFY(flavor) \
+    (MPI_WIN_FLAVOR_ALLOCATE == (flavor) || MPI_WIN_FLAVOR_SHARED == (flavor))
 
 
 typedef struct ompi_osc_ucx_component {
@@ -277,6 +290,33 @@ int ompi_osc_find_attached_region_position(ompi_osc_dynamic_win_info_t *dynamic_
 int ompi_osc_ucx_dynamic_lock(ompi_osc_ucx_module_t *module, int target);
 int ompi_osc_ucx_dynamic_unlock(ompi_osc_ucx_module_t *module, int target);
 
+int ompi_osc_ucx_put_notify(const void *origin_addr, size_t origin_count,
+                            struct ompi_datatype_t *origin_dt,
+                            int target, ptrdiff_t target_disp, size_t target_count,
+                            struct ompi_datatype_t *target_dt,
+                            int notify, struct ompi_win_t *win);
+int ompi_osc_ucx_get_notify(void *origin_addr, size_t origin_count,
+                            struct ompi_datatype_t *origin_dt,
+                            int target, ptrdiff_t target_disp, size_t target_count,
+                            struct ompi_datatype_t *target_dt,
+                            int notify, struct ompi_win_t *win);
+int ompi_osc_ucx_rput_notify(const void *origin_addr, size_t origin_count,
+                             struct ompi_datatype_t *origin_dt,
+                             int target, ptrdiff_t target_disp, size_t target_count,
+                             struct ompi_datatype_t *target_dt,
+                             int notify, struct ompi_win_t *win,
+                             struct ompi_request_t **request);
+int ompi_osc_ucx_rget_notify(void *origin_addr, size_t origin_count,
+                             struct ompi_datatype_t *origin_dt,
+                             int target, ptrdiff_t target_disp, size_t target_count,
+                             struct ompi_datatype_t *target_dt,
+                             int notify, struct ompi_win_t *win,
+                             struct ompi_request_t **request);
+int ompi_osc_ucx_win_get_notify_value(struct ompi_win_t *win, int notify,
+                                      OMPI_MPI_COUNT_TYPE *value);
+int ompi_osc_ucx_win_reset_notify_value(struct ompi_win_t *win, int notify,
+                                        OMPI_MPI_COUNT_TYPE *value);
+
 /* returns the size at the peer */
 static inline size_t ompi_osc_ucx_get_size(ompi_osc_ucx_module_t *module, int rank)
 {
diff --git a/ompi/mca/osc/ucx/osc_ucx_comm.c b/ompi/mca/osc/ucx/osc_ucx_comm.c
index 0354edb71c0..53d1cf23cfd 100644
--- a/ompi/mca/osc/ucx/osc_ucx_comm.c
+++ b/ompi/mca/osc/ucx/osc_ucx_comm.c
@@ -603,6 +603,219 @@ int ompi_osc_ucx_get(void *origin_addr, size_t origin_count,
     }
 }
 
+/* Returns the remote address of notify counter[notify] for the given target.
+ * Counters follow the target's window data (padded to a multiple of 8 bytes)
+ * in the same registered memory region (module->mem), so the rkey that covers
+ * window data also covers the counters. */
+static inline uint64_t
+osc_ucx_notify_counter_addr(ompi_osc_ucx_module_t *module, int target, int notify)
+{
+    return module->addrs[target]
+           + OMPI_OSC_UCX_NOTIFY_OFFSET(ompi_osc_ucx_get_size(module, target))
+           + (uint64_t)notify * sizeof(uint64_t);
+}
+
+/* Returns from the calling function with MPI_ERR_RMA_FLAVOR when the window's
+ * flavor has no notify counters (see OMPI_OSC_UCX_FLAVOR_HAS_NOTIFY). */
+#define CHECK_NOTIFY_FLAVOR(module) \
+    do { \
+        if (!OMPI_OSC_UCX_FLAVOR_HAS_NOTIFY((module)->flavor)) { \
+            return MPI_ERR_RMA_FLAVOR; \
+        } \
+    } while (0)
+
+/* Returns from the calling function with MPI_ERR_NOTIFY_IDX when the counter
+ * index is out of range. do/while(0) keeps the expansion a single statement. */
+#define CHECK_NOTIFY_IDX(notify) \
+    do { \
+        if ((notify) < 0 || (notify) >= OMPI_OSC_UCX_MAX_NOTIFY_COUNTERS) { \
+            return MPI_ERR_NOTIFY_IDX; \
+        } \
+    } while (0)
+
+int ompi_osc_ucx_win_get_notify_value(struct ompi_win_t *win, int notify,
+                                      OMPI_MPI_COUNT_TYPE *value)
+{
+    ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t *)win->w_osc_module;
+    int my_rank = ompi_comm_rank(module->comm);
+
+    CHECK_NOTIFY_FLAVOR(module);
+    CHECK_NOTIFY_IDX(notify);
+
+    /* Counters are local memory: just read with a barrier to ensure
+     * any preceding remote writes to this counter are visible. */
+    opal_atomic_rmb();
+    volatile uint64_t *counter =
+        (volatile uint64_t *)(module->addrs[my_rank]
+                              + OMPI_OSC_UCX_NOTIFY_OFFSET(module->size)) + notify;
+    *value = (OMPI_MPI_COUNT_TYPE)*counter;
+    return OMPI_SUCCESS;
+}
+
+int ompi_osc_ucx_win_reset_notify_value(struct ompi_win_t *win, int notify,
+                                        OMPI_MPI_COUNT_TYPE *value)
+{
+    ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t *)win->w_osc_module;
+    int my_rank = ompi_comm_rank(module->comm);
+
+    CHECK_NOTIFY_FLAVOR(module);
+    CHECK_NOTIFY_IDX(notify);
+
+    volatile uint64_t *counter =
+        (volatile uint64_t *)(module->addrs[my_rank]
+                              + OMPI_OSC_UCX_NOTIFY_OFFSET(module->size)) + notify;
+    *value = (OMPI_MPI_COUNT_TYPE)opal_atomic_swap_64((volatile int64_t *)counter, 0);
+    return OMPI_SUCCESS;
+}
+
+int ompi_osc_ucx_put_notify(const void *origin_addr, size_t origin_count,
+                            struct ompi_datatype_t *origin_dt,
+                            int target, ptrdiff_t target_disp, size_t target_count,
+                            struct ompi_datatype_t *target_dt,
+                            int notify, struct ompi_win_t *win)
+{
+    ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t *)win->w_osc_module;
+    ucp_ep_h *ep;
+    int ret;
+
+    CHECK_NOTIFY_FLAVOR(module);
+    CHECK_NOTIFY_IDX(notify);
+
+    OSC_UCX_GET_DEFAULT_EP(ep, module, target);
+
+    ret = ompi_osc_ucx_put(origin_addr, origin_count, origin_dt,
+                           target, target_disp, target_count, target_dt, win);
+    if (OMPI_SUCCESS != ret) {
+        return ret;
+    }
+
+    /* Fence to order the PUT before the counter increment, so the target
+     * does not see the notification ahead of the data. */
+    ret = opal_common_ucx_wpmem_fence(module->mem);
+    if (OPAL_SUCCESS != ret) {
+        return OMPI_ERROR;
+    }
+
+    /* Atomically increment the target's notify counter in-place using the
+     * same mem handle as the window data. */
+    ret = opal_common_ucx_wpmem_post(module->mem,
+                                     UCP_ATOMIC_POST_OP_ADD, 1,
+                                     target, sizeof(uint64_t),
+                                     osc_ucx_notify_counter_addr(module, target, notify),
+                                     ep);
+    return (OPAL_SUCCESS == ret) ? OMPI_SUCCESS : OMPI_ERROR;
+}
+
+int ompi_osc_ucx_get_notify(void *origin_addr, size_t origin_count,
+                            struct ompi_datatype_t *origin_dt,
+                            int target, ptrdiff_t target_disp, size_t target_count,
+                            struct ompi_datatype_t *target_dt,
+                            int notify, struct ompi_win_t *win)
+{
+    ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t *)win->w_osc_module;
+    ucp_ep_h *ep;
+    int ret;
+
+    CHECK_NOTIFY_FLAVOR(module);
+    CHECK_NOTIFY_IDX(notify);
+
+    OSC_UCX_GET_DEFAULT_EP(ep, module, target);
+
+    ret = ompi_osc_ucx_get(origin_addr, origin_count, origin_dt,
+                           target, target_disp, target_count, target_dt, win);
+    if (OMPI_SUCCESS != ret) {
+        return ret;
+    }
+
+    /* Flush to ensure the GET data is locally available before issuing the
+     * counter increment back to the target. */
+    ret = opal_common_ucx_ctx_flush(module->ctx, OPAL_COMMON_UCX_SCOPE_EP, target);
+    if (OPAL_SUCCESS != ret) {
+        return OMPI_ERROR;
+    }
+
+    ret = opal_common_ucx_wpmem_post(module->mem,
+                                     UCP_ATOMIC_POST_OP_ADD, 1,
+                                     target, sizeof(uint64_t),
+                                     osc_ucx_notify_counter_addr(module, target, notify),
+                                     ep);
+    return (OPAL_SUCCESS == ret) ? OMPI_SUCCESS : OMPI_ERROR;
+}
+
+int ompi_osc_ucx_rput_notify(const void *origin_addr, size_t origin_count,
+                             struct ompi_datatype_t *origin_dt,
+                             int target, ptrdiff_t target_disp, size_t target_count,
+                             struct ompi_datatype_t *target_dt,
+                             int notify, struct ompi_win_t *win,
+                             struct ompi_request_t **request)
+{
+    ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t *)win->w_osc_module;
+    ucp_ep_h *ep;
+    int ret;
+
+    CHECK_NOTIFY_FLAVOR(module);
+    CHECK_NOTIFY_IDX(notify);
+
+    OSC_UCX_GET_DEFAULT_EP(ep, module, target);
+
+    ret = ompi_osc_ucx_rput(origin_addr, origin_count, origin_dt,
+                            target, target_disp, target_count, target_dt,
+                            win, request);
+    if (OMPI_SUCCESS != ret) {
+        return ret;
+    }
+
+    /* Fence to order the PUT before the counter increment. */
+    ret = opal_common_ucx_wpmem_fence(module->mem);
+    if (OPAL_SUCCESS != ret) {
+        return OMPI_ERROR;
+    }
+
+    ret = opal_common_ucx_wpmem_post(module->mem,
+                                     UCP_ATOMIC_POST_OP_ADD, 1,
+                                     target, sizeof(uint64_t),
+                                     osc_ucx_notify_counter_addr(module, target, notify),
+                                     ep);
+    return (OPAL_SUCCESS == ret) ? OMPI_SUCCESS : OMPI_ERROR;
+}
+
+int ompi_osc_ucx_rget_notify(void *origin_addr, size_t origin_count,
+                             struct ompi_datatype_t *origin_dt,
+                             int target, ptrdiff_t target_disp, size_t target_count,
+                             struct ompi_datatype_t *target_dt,
+                             int notify, struct ompi_win_t *win,
+                             struct ompi_request_t **request)
+{
+    ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t *)win->w_osc_module;
+    ucp_ep_h *ep;
+    int ret;
+
+    CHECK_NOTIFY_FLAVOR(module);
+    CHECK_NOTIFY_IDX(notify);
+
+    OSC_UCX_GET_DEFAULT_EP(ep, module, target);
+
+    ret = ompi_osc_ucx_rget(origin_addr, origin_count, origin_dt,
+                            target, target_disp, target_count, target_dt,
+                            win, request);
+    if (OMPI_SUCCESS != ret) {
+        return ret;
+    }
+
+    /* Flush to ensure GET data is locally available before notifying target. */
+    ret = opal_common_ucx_ctx_flush(module->ctx, OPAL_COMMON_UCX_SCOPE_EP, target);
+    if (OPAL_SUCCESS != ret) {
+        return OMPI_ERROR;
+    }
+
+    ret = opal_common_ucx_wpmem_post(module->mem,
+                                     UCP_ATOMIC_POST_OP_ADD, 1,
+                                     target, sizeof(uint64_t),
+                                     osc_ucx_notify_counter_addr(module, target, notify),
+                                     ep);
+    return (OPAL_SUCCESS == ret) ? OMPI_SUCCESS : OMPI_ERROR;
+}
+
 static inline bool ompi_osc_need_acc_lock(ompi_osc_ucx_module_t *module, int target)
 {
     ompi_osc_ucx_lock_t *lock = NULL;
diff --git a/ompi/mca/osc/ucx/osc_ucx_component.c b/ompi/mca/osc/ucx/osc_ucx_component.c
index 635a53a3e0f..fcf9d23aee4 100644
--- a/ompi/mca/osc/ucx/osc_ucx_component.c
+++ b/ompi/mca/osc/ucx/osc_ucx_component.c
@@ -102,6 +102,13 @@ ompi_osc_ucx_module_t ompi_osc_ucx_module_template = {
     .osc_fetch_and_op = ompi_osc_ucx_fetch_and_op,
     .osc_get_accumulate = ompi_osc_ucx_get_accumulate,
 
+    .osc_put_notify = ompi_osc_ucx_put_notify,
+    .osc_get_notify = ompi_osc_ucx_get_notify,
+    .osc_rput_notify = ompi_osc_ucx_rput_notify,
+    .osc_rget_notify = ompi_osc_ucx_rget_notify,
+    .osc_win_get_notify_value = ompi_osc_ucx_win_get_notify_value,
+    .osc_win_reset_notify_value = ompi_osc_ucx_win_reset_notify_value,
+
     .osc_rput = ompi_osc_ucx_rput,
     .osc_rget = ompi_osc_ucx_rget,
     .osc_raccumulate = ompi_osc_ucx_raccumulate,
@@ -785,8 +792,12 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, pt
 
         /* create the segment */
         size_t total = 0;
+        size_t notify_size = OMPI_OSC_UCX_MAX_NOTIFY_COUNTERS * sizeof(uint64_t);
         for (i = 0 ; i < comm_size ; ++i) {
-            total += ompi_osc_ucx_get_size(module, i);
+            /* each rank's slot holds its window data, padded to a multiple of
+             * 8 bytes, plus its notify counters */
+            total += OMPI_OSC_UCX_NOTIFY_OFFSET(ompi_osc_ucx_get_size(module, i))
+                     + notify_size;
         }
 
         module->segment_base = NULL;
@@ -849,14 +860,17 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, pt
             goto error;
         }
 
-
+        /* Each rank's window slot is (padded peer_size + notify_size) bytes;
+         * the notify counters for rank i are at
+         * shmem_addrs[i] + OMPI_OSC_UCX_NOTIFY_OFFSET(peer_size). */
         for (i = 0, total = 0; i < comm_size ; ++i) {
             size_t peer_size = ompi_osc_ucx_get_size(module, i);
             if (peer_size || !module->noncontig_shared_win) {
                 module->shmem_addrs[i] = ((uint64_t) module->segment_base) + total;
-                total += peer_size;
+                total += OMPI_OSC_UCX_NOTIFY_OFFSET(peer_size) + notify_size;
             } else {
                 module->shmem_addrs[i] = (uint64_t)NULL;
+                total += notify_size;
             }
         }
 
@@ -884,7 +898,21 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, pt
         ret = OMPI_ERR_BAD_PARAM;
         goto error;
     }
-    ret = opal_common_ucx_wpmem_create(module->ctx, mem_base, module->size,
+    /* Append notify counters after the (8-byte padded) window data in the
+     * same registered memory region so remote atomics on the counters can use
+     * the same rkey as the window data. Only flavors whose memory this
+     * component sizes get the extension: for ALLOCATE the UCX allocator
+     * hands back a buffer of the extended size, and for SHARED the segment
+     * slot already reserves the space. A CREATE window wraps a user buffer of
+     * only module->size bytes, so extending it would register, and later
+     * zero, memory past the end of the user's buffer. */
+    size_t notify_reg_size = 0;
+    if (OMPI_OSC_UCX_FLAVOR_HAS_NOTIFY(flavor)) {
+        notify_reg_size = OMPI_OSC_UCX_NOTIFY_OFFSET(module->size) - module->size
+                          + OMPI_OSC_UCX_MAX_NOTIFY_COUNTERS * sizeof(uint64_t);
+    }
+    ret = opal_common_ucx_wpmem_create(module->ctx, mem_base,
+                                       module->size + notify_reg_size,
                                      mem_type, &exchange_len_info,
                                      OPAL_COMMON_UCX_WPMEM_ADDR_EXCHANGE_FULL,
                                      (void *)module->comm,
@@ -957,6 +985,14 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, pt
     module->state.acc_lock = TARGET_LOCK_UNLOCKED;
     module->state.dynamic_lock = TARGET_LOCK_UNLOCKED;
     module->state.dynamic_win_count = 0;
+
+    /* initialize notify counters to zero; they live at the 8-byte padded
+     * offset past the window data, and only exist for flavors whose memory
+     * was sized to hold them */
+    if (OMPI_OSC_UCX_FLAVOR_HAS_NOTIFY(flavor) && *base != NULL) {
+        memset((char *)*base + OMPI_OSC_UCX_NOTIFY_OFFSET(module->size), 0,
+               OMPI_OSC_UCX_MAX_NOTIFY_COUNTERS * sizeof(uint64_t));
+    }
     for (i = 0; i < OMPI_OSC_UCX_ATTACH_MAX; i++) {
         module->local_dynamic_win_info[i].refcnt = 0;
     }