Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions ompi/mca/osc/ucx/osc_ucx.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#define OMPI_OSC_UCX_POST_PEER_MAX 32
#define OMPI_OSC_UCX_ATTACH_MAX 48
#define OMPI_OSC_UCX_MEM_ADDR_MAX_LEN 1024
#define OMPI_OSC_UCX_MAX_NOTIFY_COUNTERS 16


typedef struct ompi_osc_ucx_component {
Expand Down Expand Up @@ -277,6 +278,33 @@ int ompi_osc_find_attached_region_position(ompi_osc_dynamic_win_info_t *dynamic_
int ompi_osc_ucx_dynamic_lock(ompi_osc_ucx_module_t *module, int target);
int ompi_osc_ucx_dynamic_unlock(ompi_osc_ucx_module_t *module, int target);

int ompi_osc_ucx_put_notify(const void *origin_addr, size_t origin_count,
struct ompi_datatype_t *origin_dt,
int target, ptrdiff_t target_disp, size_t target_count,
struct ompi_datatype_t *target_dt,
int notify, struct ompi_win_t *win);
int ompi_osc_ucx_get_notify(void *origin_addr, size_t origin_count,
struct ompi_datatype_t *origin_dt,
int target, ptrdiff_t target_disp, size_t target_count,
struct ompi_datatype_t *target_dt,
int notify, struct ompi_win_t *win);
int ompi_osc_ucx_rput_notify(const void *origin_addr, size_t origin_count,
struct ompi_datatype_t *origin_dt,
int target, ptrdiff_t target_disp, size_t target_count,
struct ompi_datatype_t *target_dt,
int notify, struct ompi_win_t *win,
struct ompi_request_t **request);
int ompi_osc_ucx_rget_notify(void *origin_addr, size_t origin_count,
struct ompi_datatype_t *origin_dt,
int target, ptrdiff_t target_disp, size_t target_count,
struct ompi_datatype_t *target_dt,
int notify, struct ompi_win_t *win,
struct ompi_request_t **request);
int ompi_osc_ucx_win_get_notify_value(struct ompi_win_t *win, int notify,
OMPI_MPI_COUNT_TYPE *value);
int ompi_osc_ucx_win_reset_notify_value(struct ompi_win_t *win, int notify,
OMPI_MPI_COUNT_TYPE *value);

/* returns the size at the peer */
static inline size_t ompi_osc_ucx_get_size(ompi_osc_ucx_module_t *module, int rank)
{
Expand Down
192 changes: 192 additions & 0 deletions ompi/mca/osc/ucx/osc_ucx_comm.c
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,198 @@ int ompi_osc_ucx_get(void *origin_addr, size_t origin_count,
}
}

/* Returns the remote address of notify counter[notify] for the given target.
* Counters are appended directly after the target's window data in the same
* registered memory region (module->mem), so the rkey that covers window data
* also covers the counters. */
static inline uint64_t
osc_ucx_notify_counter_addr(ompi_osc_ucx_module_t *module, int target, int notify)
{
return module->addrs[target]
+ ompi_osc_ucx_get_size(module, target)
+ (uint64_t)notify * sizeof(uint64_t);
}

#define CHECK_NOTIFY_IDX(notify) \
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest passing in the module and checking against a value stored in there. OMPI_OSC_UCX_MAX_NOTIFY_COUNTERS is just a crutch until we have a better solution

if ((notify) < 0 || (notify) >= OMPI_OSC_UCX_MAX_NOTIFY_COUNTERS) { \
return MPI_ERR_NOTIFY_IDX; \
}

int ompi_osc_ucx_win_get_notify_value(struct ompi_win_t *win, int notify,
OMPI_MPI_COUNT_TYPE *value)
{
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t *)win->w_osc_module;
int my_rank = ompi_comm_rank(module->comm);

CHECK_NOTIFY_IDX(notify);

/* Counters are local memory — just read with a barrier to ensure
* any preceding remote writes to this counter are visible. */
opal_atomic_rmb();
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure this actually does anything. Shouldn't the read barrier come after the counter read to prevent subsequent reads from being reordered?

volatile uint64_t *counter =
(volatile uint64_t *)(module->addrs[my_rank] + module->size) + notify;
*value = (OMPI_MPI_COUNT_TYPE)*counter;
return OMPI_SUCCESS;
}

int ompi_osc_ucx_win_reset_notify_value(struct ompi_win_t *win, int notify,
OMPI_MPI_COUNT_TYPE *value)
{
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t *)win->w_osc_module;
int my_rank = ompi_comm_rank(module->comm);

CHECK_NOTIFY_IDX(notify);

volatile uint64_t *counter =
(volatile uint64_t *)(module->addrs[my_rank] + module->size) + notify;
*value = (OMPI_MPI_COUNT_TYPE)opal_atomic_swap_64((volatile int64_t *)counter, 0);
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You will need to use UCX to swap the value, otherwise it is not guaranteed to be atomic wrt to other network atomic operations.

return OMPI_SUCCESS;
}

int ompi_osc_ucx_put_notify(const void *origin_addr, size_t origin_count,
struct ompi_datatype_t *origin_dt,
int target, ptrdiff_t target_disp, size_t target_count,
struct ompi_datatype_t *target_dt,
int notify, struct ompi_win_t *win)
{
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t *)win->w_osc_module;
ucp_ep_h *ep;
int ret;

CHECK_NOTIFY_IDX(notify);

OSC_UCX_GET_DEFAULT_EP(ep, module, target);

ret = ompi_osc_ucx_put(origin_addr, origin_count, origin_dt,
target, target_disp, target_count, target_dt, win);
if (OMPI_SUCCESS != ret) {
return ret;
}

/* Flush to ensure the PUT is visible at the target before the counter
* increment arrives. */
ret = opal_common_ucx_wpmem_fence(module->mem);
if (OPAL_SUCCESS != ret) {
return OMPI_ERROR;
}

/* Atomically increment the target's notify counter in-place using the
* same mem handle as the window data. */
ret = opal_common_ucx_wpmem_post(module->mem,
UCP_ATOMIC_POST_OP_ADD, 1,
target, sizeof(uint64_t),
osc_ucx_notify_counter_addr(module, target, notify),
ep);
return (OPAL_SUCCESS == ret) ? OMPI_SUCCESS : OMPI_ERROR;
}

int ompi_osc_ucx_get_notify(void *origin_addr, size_t origin_count,
struct ompi_datatype_t *origin_dt,
int target, ptrdiff_t target_disp, size_t target_count,
struct ompi_datatype_t *target_dt,
int notify, struct ompi_win_t *win)
{
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t *)win->w_osc_module;
ucp_ep_h *ep;
int ret;

CHECK_NOTIFY_IDX(notify);

OSC_UCX_GET_DEFAULT_EP(ep, module, target);

ret = ompi_osc_ucx_get(origin_addr, origin_count, origin_dt,
target, target_disp, target_count, target_dt, win);
if (OMPI_SUCCESS != ret) {
return ret;
}

/* Flush to ensure the GET data is locally available before issuing the
* counter increment back to the target. */
ret = opal_common_ucx_ctx_flush(module->ctx, OPAL_COMMON_UCX_SCOPE_EP, target);
if (OPAL_SUCCESS != ret) {
return OMPI_ERROR;
}

ret = opal_common_ucx_wpmem_post(module->mem,
UCP_ATOMIC_POST_OP_ADD, 1,
target, sizeof(uint64_t),
osc_ucx_notify_counter_addr(module, target, notify),
ep);
return (OPAL_SUCCESS == ret) ? OMPI_SUCCESS : OMPI_ERROR;
}

int ompi_osc_ucx_rput_notify(const void *origin_addr, size_t origin_count,
struct ompi_datatype_t *origin_dt,
int target, ptrdiff_t target_disp, size_t target_count,
struct ompi_datatype_t *target_dt,
int notify, struct ompi_win_t *win,
struct ompi_request_t **request)
{
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t *)win->w_osc_module;
ucp_ep_h *ep;
int ret;

CHECK_NOTIFY_IDX(notify);

OSC_UCX_GET_DEFAULT_EP(ep, module, target);

ret = ompi_osc_ucx_rput(origin_addr, origin_count, origin_dt,
target, target_disp, target_count, target_dt,
win, request);
if (OMPI_SUCCESS != ret) {
return ret;
}

/* Fence to order the PUT before the counter increment. */
ret = opal_common_ucx_wpmem_fence(module->mem);
if (OPAL_SUCCESS != ret) {
return OMPI_ERROR;
}

ret = opal_common_ucx_wpmem_post(module->mem,
UCP_ATOMIC_POST_OP_ADD, 1,
target, sizeof(uint64_t),
osc_ucx_notify_counter_addr(module, target, notify),
ep);
return (OPAL_SUCCESS == ret) ? OMPI_SUCCESS : OMPI_ERROR;
}

int ompi_osc_ucx_rget_notify(void *origin_addr, size_t origin_count,
struct ompi_datatype_t *origin_dt,
int target, ptrdiff_t target_disp, size_t target_count,
struct ompi_datatype_t *target_dt,
int notify, struct ompi_win_t *win,
struct ompi_request_t **request)
{
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t *)win->w_osc_module;
ucp_ep_h *ep;
int ret;

CHECK_NOTIFY_IDX(notify);

OSC_UCX_GET_DEFAULT_EP(ep, module, target);

ret = ompi_osc_ucx_rget(origin_addr, origin_count, origin_dt,
target, target_disp, target_count, target_dt,
win, request);
if (OMPI_SUCCESS != ret) {
return ret;
}

/* Flush to ensure GET data is locally available before notifying target. */
ret = opal_common_ucx_ctx_flush(module->ctx, OPAL_COMMON_UCX_SCOPE_EP, target);
if (OPAL_SUCCESS != ret) {
return OMPI_ERROR;
}

ret = opal_common_ucx_wpmem_post(module->mem,
UCP_ATOMIC_POST_OP_ADD, 1,
target, sizeof(uint64_t),
osc_ucx_notify_counter_addr(module, target, notify),
ep);
return (OPAL_SUCCESS == ret) ? OMPI_SUCCESS : OMPI_ERROR;
}

static inline bool ompi_osc_need_acc_lock(ompi_osc_ucx_module_t *module, int target)
{
ompi_osc_ucx_lock_t *lock = NULL;
Expand Down
34 changes: 30 additions & 4 deletions ompi/mca/osc/ucx/osc_ucx_component.c
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,13 @@ ompi_osc_ucx_module_t ompi_osc_ucx_module_template = {
.osc_fetch_and_op = ompi_osc_ucx_fetch_and_op,
.osc_get_accumulate = ompi_osc_ucx_get_accumulate,

.osc_put_notify = ompi_osc_ucx_put_notify,
.osc_get_notify = ompi_osc_ucx_get_notify,
.osc_rput_notify = ompi_osc_ucx_rput_notify,
.osc_rget_notify = ompi_osc_ucx_rget_notify,
.osc_win_get_notify_value = ompi_osc_ucx_win_get_notify_value,
.osc_win_reset_notify_value = ompi_osc_ucx_win_reset_notify_value,

.osc_rput = ompi_osc_ucx_rput,
.osc_rget = ompi_osc_ucx_rget,
.osc_raccumulate = ompi_osc_ucx_raccumulate,
Expand Down Expand Up @@ -785,8 +792,10 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, pt
/* create the segment */

size_t total = 0;
size_t notify_size = OMPI_OSC_UCX_MAX_NOTIFY_COUNTERS * sizeof(uint64_t);
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replace the use of OMPI_OSC_UCX_MAX_NOTIFY_COUNTERS with a local variable that we can set to something that comes out of the info object. That makes it easier going forward. Also, store the value of that variable in the module (see my earlier comment).

for (i = 0 ; i < comm_size ; ++i) {
total += ompi_osc_ucx_get_size(module, i);
/* each rank's slot holds its window data plus its notify counters */
total += ompi_osc_ucx_get_size(module, i) + notify_size;
}

module->segment_base = NULL;
Expand Down Expand Up @@ -849,14 +858,16 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, pt
goto error;
}


/* Each rank's window slot is (peer_size + notify_size) bytes; the
* notify counters for rank i are at shmem_addrs[i] + peer_size. */
for (i = 0, total = 0; i < comm_size ; ++i) {
size_t peer_size = ompi_osc_ucx_get_size(module, i);
if (peer_size || !module->noncontig_shared_win) {
module->shmem_addrs[i] = ((uint64_t) module->segment_base) + total;
total += peer_size;
total += peer_size + notify_size;
} else {
module->shmem_addrs[i] = (uint64_t)NULL;
total += notify_size;
}
}

Expand Down Expand Up @@ -884,7 +895,16 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, pt
ret = OMPI_ERR_BAD_PARAM;
goto error;
}
ret = opal_common_ucx_wpmem_create(module->ctx, mem_base, module->size,
/* Append notify counters after the window data in the same registered
* memory region. For ALLOCATE flavor the UCX allocator will hand back
* a buffer of this extended size; for CREATE/SHARED the user buffer is
* large enough to hold only the window data, but we still register the
* extra bytes so that remote atomic operations on the counters can use
* the same rkey as the window data. */
size_t notify_reg_size = (flavor == MPI_WIN_FLAVOR_DYNAMIC) ? 0 :
OMPI_OSC_UCX_MAX_NOTIFY_COUNTERS * sizeof(uint64_t);
ret = opal_common_ucx_wpmem_create(module->ctx, mem_base,
module->size + notify_reg_size,
mem_type, &exchange_len_info,
OPAL_COMMON_UCX_WPMEM_ADDR_EXCHANGE_FULL,
(void *)module->comm,
Expand Down Expand Up @@ -957,6 +977,12 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, pt
module->state.acc_lock = TARGET_LOCK_UNLOCKED;
module->state.dynamic_lock = TARGET_LOCK_UNLOCKED;
module->state.dynamic_win_count = 0;

/* initialize notify counters to zero; they live at base + size */
if (flavor != MPI_WIN_FLAVOR_DYNAMIC && *base != NULL) {
memset((char *)*base + module->size, 0,
OMPI_OSC_UCX_MAX_NOTIFY_COUNTERS * sizeof(uint64_t));
}
for (i = 0; i < OMPI_OSC_UCX_ATTACH_MAX; i++) {
module->local_dynamic_win_info[i].refcnt = 0;
}
Expand Down
Loading