Skip to content

Commit

Permalink
REVIEW: add common function for radix select
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergei-Lebedev committed Feb 28, 2025
1 parent 763b023 commit 4cc79d6
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 17 deletions.
10 changes: 2 additions & 8 deletions src/components/tl/ucp/allgather/allgather_knomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -319,15 +319,9 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_init(ucc_base_coll_args_t *coll_args,
ucc_memory_type_t mtype = GET_MT(&coll_args->args);
size_t count = GET_TOTAL_COUNT(&coll_args->args, tsize);
ucc_datatype_t dtype = GET_DT(&coll_args->args);
ucc_kn_radix_t radix, cfg_radix, opt_radix;
ucc_kn_radix_t radix;

opt_radix = (mtype == UCC_MEMORY_TYPE_HOST) ? tl_team->opt_radix_host :
tl_team->opt_radix;

cfg_radix = ucc_tl_ucp_get_radix_from_range(tl_team,
count * ucc_dt_size(dtype),
mtype, p, opt_radix);
radix = ucc_min(cfg_radix, tsize);
radix = ucc_tl_ucp_get_knomial_radix(tl_team, count, dtype, mtype, p, 0);

return ucc_tl_ucp_allgather_knomial_init_r(coll_args, team, task_h, radix);
}
11 changes: 2 additions & 9 deletions src/components/tl/ucp/allreduce/allreduce_sra_knomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ ucc_tl_ucp_allreduce_sra_knomial_frag_init(ucc_base_coll_args_t *coll_args,
ucc_schedule_t *schedule;
ucc_coll_task_t *task, *rs_task;
ucc_status_t status;
ucc_kn_radix_t radix, cfg_radix, opt_radix;
ucc_kn_radix_t radix;
size_t count;

status = ucc_tl_ucp_get_schedule(tl_team, coll_args,
Expand All @@ -112,14 +112,7 @@ ucc_tl_ucp_allreduce_sra_knomial_frag_init(ucc_base_coll_args_t *coll_args,
count = coll_args->args.dst.info.count;
}

opt_radix = (mem_type == UCC_MEMORY_TYPE_HOST) ? tl_team->opt_radix_host :
tl_team->opt_radix;
cfg_radix = ucc_tl_ucp_get_radix_from_range(tl_team,
count * ucc_dt_size(dtype),
mem_type, p, opt_radix);
radix = ucc_knomial_pattern_get_min_radix(cfg_radix,
UCC_TL_TEAM_SIZE(tl_team),
count);
radix = ucc_tl_ucp_get_knomial_radix(tl_team, count, dtype, mem_type, p, 1);
/* 1st step of allreduce: knomial reduce_scatter */
UCC_CHECK_GOTO(
ucc_tl_ucp_reduce_scatter_knomial_init_r(&args, team, &task, radix),
Expand Down
29 changes: 29 additions & 0 deletions src/components/tl/ucp/tl_ucp_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -508,4 +508,33 @@ ucc_tl_ucp_get_radix_from_range(ucc_tl_ucp_team_t *team,
}
return radix;
}

/*
 * Select the radix to use for a knomial pattern.
 *
 * The team's preferred radix (opt_radix_host for UCC_MEMORY_TYPE_HOST,
 * opt_radix for any other memory type) is handed to
 * ucc_tl_ucp_get_radix_from_range() together with the message size in bytes
 * (count * ucc_dt_size(dtype)), the memory type and the range parameter p.
 * The value it returns is then limited in one of two ways:
 *   - need_scratch != 0: it is passed through
 *     ucc_knomial_pattern_get_min_radix() along with the team size and the
 *     element count (intended for algorithms that use a scratch buffer);
 *   - need_scratch == 0: it is capped at the team size.
 */
static inline unsigned ucc_tl_ucp_get_knomial_radix(ucc_tl_ucp_team_t *team,
                                                    size_t count,
                                                    ucc_datatype_t dtype,
                                                    ucc_memory_type_t mem_type,
                                                    ucc_mrange_uint_t *p,
                                                    int need_scratch)
{
    size_t   nbytes = count * ucc_dt_size(dtype);
    unsigned preferred;
    unsigned selected;

    /* Host memory has its own preferred radix */
    if (mem_type == UCC_MEMORY_TYPE_HOST) {
        preferred = team->opt_radix_host;
    } else {
        preferred = team->opt_radix;
    }

    selected = ucc_tl_ucp_get_radix_from_range(team, nbytes, mem_type, p,
                                               preferred);

    if (!need_scratch) {
        /* No scratch constraint: the radix only must not exceed team size */
        return ucc_min(selected, UCC_TL_TEAM_SIZE(team));
    }

    return ucc_knomial_pattern_get_min_radix(selected,
                                             UCC_TL_TEAM_SIZE(team), count);
}

#endif

0 comments on commit 4cc79d6

Please sign in to comment.