Skip to content

Commit

Permalink
TL/UCP: add radix selection to kn radix (#1072)
Browse files Browse the repository at this point in the history
* TL/UCP: add radix selection to kn radix

* REVIEW: add common function for radix select
  • Loading branch information
Sergei-Lebedev authored Feb 28, 2025
1 parent 8892c73 commit 9a4e209
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 27 deletions.
9 changes: 7 additions & 2 deletions src/components/tl/ucp/allgather/allgather_knomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -314,9 +314,14 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_init(ucc_base_coll_args_t *coll_args,
ucc_coll_task_t **task_h)
{
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
ucc_rank_t size = UCC_TL_TEAM_SIZE(tl_team);
ucc_mrange_uint_t *p = &tl_team->cfg.allgather_kn_radix;
ucc_rank_t tsize = UCC_TL_TEAM_SIZE(tl_team);
ucc_memory_type_t mtype = GET_MT(&coll_args->args);
size_t count = GET_TOTAL_COUNT(&coll_args->args, tsize);
ucc_datatype_t dtype = GET_DT(&coll_args->args);
ucc_kn_radix_t radix;

radix = ucc_min(UCC_TL_UCP_TEAM_LIB(tl_team)->cfg.allgather_kn_radix, size);
radix = ucc_tl_ucp_get_knomial_radix(tl_team, count, dtype, mtype, p, 0);

return ucc_tl_ucp_allgather_knomial_init_r(coll_args, team, task_h, radix);
}
10 changes: 2 additions & 8 deletions src/components/tl/ucp/allreduce/allreduce_sra_knomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ ucc_tl_ucp_allreduce_sra_knomial_frag_init(ucc_base_coll_args_t *coll_args,
ucc_schedule_t *schedule;
ucc_coll_task_t *task, *rs_task;
ucc_status_t status;
ucc_kn_radix_t radix, cfg_radix;
ucc_kn_radix_t radix;
size_t count;

status = ucc_tl_ucp_get_schedule(tl_team, coll_args,
Expand All @@ -112,13 +112,7 @@ ucc_tl_ucp_allreduce_sra_knomial_frag_init(ucc_base_coll_args_t *coll_args,
count = coll_args->args.dst.info.count;
}

cfg_radix = ucc_tl_ucp_get_radix_from_range(tl_team,
count * ucc_dt_size(dtype),
mem_type, p,
tl_team->opt_radix);
radix = ucc_knomial_pattern_get_min_radix(cfg_radix,
UCC_TL_TEAM_SIZE(tl_team),
count);
radix = ucc_tl_ucp_get_knomial_radix(tl_team, count, dtype, mem_type, p, 1);
/* 1st step of allreduce: knomial reduce_scatter */
UCC_CHECK_GOTO(
ucc_tl_ucp_reduce_scatter_knomial_init_r(&args, team, &task, radix),
Expand Down
8 changes: 5 additions & 3 deletions src/components/tl/ucp/bcast/bcast_sag_knomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -82,17 +82,19 @@ ucc_tl_ucp_bcast_sag_knomial_init(ucc_base_coll_args_t *coll_args,
ucc_schedule_t *schedule;
ucc_coll_task_t *task, *rs_task;
ucc_status_t status;
ucc_kn_radix_t radix, cfg_radix;
ucc_kn_radix_t radix, cfg_radix, opt_radix;

if (UCC_COLL_ARGS_ACTIVE_SET(&coll_args->args)) {
/* ActiveSets currently are only supported with KN alg */
return ucc_tl_ucp_bcast_knomial_init(coll_args, team, task_h);
}

opt_radix = (mem_type == UCC_MEMORY_TYPE_HOST) ? tl_team->opt_radix_host :
tl_team->opt_radix;

cfg_radix = ucc_tl_ucp_get_radix_from_range(tl_team,
count * ucc_dt_size(dtype),
mem_type, p,
tl_team->opt_radix);
mem_type, p, opt_radix);
radix = ucc_knomial_pattern_get_min_radix(cfg_radix,
UCC_TL_TEAM_SIZE(tl_team),
count);
Expand Down
4 changes: 2 additions & 2 deletions src/components/tl/ucp/tl_ucp.c
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,9 @@ ucc_config_field_t ucc_tl_ucp_lib_config_table[] = {
ucc_offsetof(ucc_tl_ucp_lib_config_t, reduce_scatter_kn_radix),
UCC_CONFIG_TYPE_UINT},

{"ALLGATHER_KN_RADIX", "4", "Radix of the knomial allgather algorithm",
{"ALLGATHER_KN_RADIX", "auto", "Radix of the knomial allgather algorithm",
ucc_offsetof(ucc_tl_ucp_lib_config_t, allgather_kn_radix),
UCC_CONFIG_TYPE_UINT},
UCC_CONFIG_TYPE_UINT_RANGED},

{"BCAST_KN_RADIX", "4", "Radix of the recursive-knomial bcast algorithm",
ucc_offsetof(ucc_tl_ucp_lib_config_t, bcast_kn_radix),
Expand Down
5 changes: 3 additions & 2 deletions src/components/tl/ucp/tl_ucp.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ typedef struct ucc_tl_ucp_lib_config {
ucc_mrange_uint_t allreduce_kn_radix;
ucc_mrange_uint_t allreduce_sra_kn_radix;
uint32_t reduce_scatter_kn_radix;
uint32_t allgather_kn_radix;
ucc_mrange_uint_t allgather_kn_radix;
uint32_t bcast_kn_radix;
ucc_mrange_uint_t bcast_sag_kn_radix;
uint32_t reduce_kn_radix;
Expand Down Expand Up @@ -145,7 +145,8 @@ typedef struct ucc_tl_ucp_team {
const char * tuning_str;
ucc_topo_t *topo;
ucc_ep_map_t ctx_map;
ucc_rank_t opt_radix;
ucc_rank_t opt_radix; /* generic opt radix */
ucc_rank_t opt_radix_host; /* host specific opt radix */
} ucc_tl_ucp_team_t;
UCC_CLASS_DECLARE(ucc_tl_ucp_team_t, ucc_base_context_t *,
const ucc_base_team_params_t *);
Expand Down
29 changes: 29 additions & 0 deletions src/components/tl/ucp/tl_ucp_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -508,4 +508,33 @@ ucc_tl_ucp_get_radix_from_range(ucc_tl_ucp_team_t *team,
}
return radix;
}

/*
* Get the radix for knomial patterns.
* If need_scratch is true, the radix is the minimum radix that can be used to fit into scratch buffer.
* Otherwise, the radix is the minimum radix that can be used to fit into team size.
*/
static inline unsigned ucc_tl_ucp_get_knomial_radix(ucc_tl_ucp_team_t *team,
size_t count,
ucc_datatype_t dtype,
ucc_memory_type_t mem_type,
ucc_mrange_uint_t *p,
int need_scratch)
{
size_t msgsize = count * ucc_dt_size(dtype);
unsigned opt_radix, cfg_radix, radix;

opt_radix = (mem_type == UCC_MEMORY_TYPE_HOST) ? team->opt_radix_host :
team->opt_radix;
cfg_radix = ucc_tl_ucp_get_radix_from_range(team, msgsize, mem_type, p,
opt_radix);
if (need_scratch) {
radix = ucc_knomial_pattern_get_min_radix(cfg_radix, UCC_TL_TEAM_SIZE(team), count);
} else {
radix = ucc_min(cfg_radix, UCC_TL_TEAM_SIZE(team));

}
return radix;
}

#endif
1 change: 0 additions & 1 deletion src/components/tl/ucp/tl_ucp_lib.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ UCC_CLASS_INIT_FUNC(ucc_tl_ucp_lib_t, const ucc_base_lib_params_t *params,
if (tl_ucp_config->kn_radix > 0) {
self->cfg.barrier_kn_radix = tl_ucp_config->kn_radix;
self->cfg.reduce_scatter_kn_radix = tl_ucp_config->kn_radix;
self->cfg.allgather_kn_radix = tl_ucp_config->kn_radix;
self->cfg.bcast_kn_radix = tl_ucp_config->kn_radix;
self->cfg.reduce_kn_radix = tl_ucp_config->kn_radix;
self->cfg.scatter_kn_radix = tl_ucp_config->kn_radix;
Expand Down
30 changes: 21 additions & 9 deletions src/components/tl/ucp/tl_ucp_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ static inline ucc_status_t ucc_tl_ucp_get_topo(ucc_tl_ucp_team_t *team)
UCC_CLASS_INIT_FUNC(ucc_tl_ucp_team_t, ucc_base_context_t *tl_context,
const ucc_base_team_params_t *params)
{
ucc_tl_ucp_context_t *ctx =
ucc_derived_of(tl_context, ucc_tl_ucp_context_t);
ucc_tl_ucp_context_t *ctx = ucc_derived_of(tl_context,
ucc_tl_ucp_context_t);
ucc_kn_radix_t max_radix, min_radix;
ucc_rank_t tsize;
ucc_rank_t tsize, max_ppn;
ucc_status_t status;

UCC_CLASS_CALL_SUPER_INIT(ucc_tl_team_t, &ctx->super, params);
Expand All @@ -59,6 +59,7 @@ UCC_CLASS_INIT_FUNC(ucc_tl_ucp_team_t, ucc_base_context_t *tl_context,
self->tuning_str = "";
self->topo = NULL;
self->opt_radix = UCC_UUNITS_AUTO_RADIX;
self->opt_radix_host = UCC_UUNITS_AUTO_RADIX;

status = ucc_config_clone_table(&UCC_TL_UCP_TEAM_LIB(self)->cfg, &self->cfg,
ucc_tl_ucp_lib_config_table);
Expand Down Expand Up @@ -91,14 +92,25 @@ UCC_CLASS_INIT_FUNC(ucc_tl_ucp_team_t, ucc_base_context_t *tl_context,
self->cfg.use_reordering = 0;
}

if (self->topo && !UCC_TL_IS_SERVICE_TEAM(self) &&
self->topo->topo->sock_bound) {
if (self->topo && !UCC_TL_IS_SERVICE_TEAM(self)) {
tsize = UCC_TL_TEAM_SIZE(self);
max_radix = (ucc_topo_max_ppn(self->topo) == 1) ? tsize :
ucc_min(tsize, ucc_topo_min_socket_size(self->topo));
min_radix = ucc_min(tsize, ucc_topo_max_ppn(self->topo) == 1 ? 3: 2);
max_ppn = ucc_topo_max_ppn(self->topo);

min_radix = ucc_min(tsize, 3);
max_radix = tsize;
self->opt_radix = ucc_kn_get_opt_radix(tsize, min_radix, max_radix);
tl_debug(tl_context->lib, "opt knomial radix: %d", self->opt_radix);
if (max_ppn == 1) {
self->opt_radix_host = self->opt_radix;
} else {
if (self->topo->topo->sock_bound) {
min_radix = 2;
max_radix = ucc_min(tsize, ucc_topo_min_socket_size(self->topo));
self->opt_radix_host = ucc_kn_get_opt_radix(tsize, min_radix,
max_radix);
}
}
tl_debug(tl_context->lib, "opt knomial radix: general %d host %d",
self->opt_radix, self->opt_radix_host);
}

tl_debug(tl_context->lib, "posted tl team: %p", self);
Expand Down

0 comments on commit 9a4e209

Please sign in to comment.