diff --git a/src/components/tl/ucp/allgather/allgather_knomial.c b/src/components/tl/ucp/allgather/allgather_knomial.c index 3b78c387cd..1ee1673844 100644 --- a/src/components/tl/ucp/allgather/allgather_knomial.c +++ b/src/components/tl/ucp/allgather/allgather_knomial.c @@ -319,15 +319,9 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_init(ucc_base_coll_args_t *coll_args, ucc_memory_type_t mtype = GET_MT(&coll_args->args); size_t count = GET_TOTAL_COUNT(&coll_args->args, tsize); ucc_datatype_t dtype = GET_DT(&coll_args->args); - ucc_kn_radix_t radix, cfg_radix, opt_radix; + ucc_kn_radix_t radix; - opt_radix = (mtype == UCC_MEMORY_TYPE_HOST) ? tl_team->opt_radix_host : - tl_team->opt_radix; - - cfg_radix = ucc_tl_ucp_get_radix_from_range(tl_team, - count * ucc_dt_size(dtype), - mtype, p, opt_radix); - radix = ucc_min(cfg_radix, tsize); + radix = ucc_tl_ucp_get_knomial_radix(tl_team, count, dtype, mtype, p, 0); return ucc_tl_ucp_allgather_knomial_init_r(coll_args, team, task_h, radix); } diff --git a/src/components/tl/ucp/allreduce/allreduce_sra_knomial.c b/src/components/tl/ucp/allreduce/allreduce_sra_knomial.c index c20918bd28..31cef191a5 100644 --- a/src/components/tl/ucp/allreduce/allreduce_sra_knomial.c +++ b/src/components/tl/ucp/allreduce/allreduce_sra_knomial.c @@ -97,7 +97,7 @@ ucc_tl_ucp_allreduce_sra_knomial_frag_init(ucc_base_coll_args_t *coll_args, ucc_schedule_t *schedule; ucc_coll_task_t *task, *rs_task; ucc_status_t status; - ucc_kn_radix_t radix, cfg_radix, opt_radix; + ucc_kn_radix_t radix; size_t count; status = ucc_tl_ucp_get_schedule(tl_team, coll_args, @@ -112,14 +112,7 @@ ucc_tl_ucp_allreduce_sra_knomial_frag_init(ucc_base_coll_args_t *coll_args, count = coll_args->args.dst.info.count; } - opt_radix = (mem_type == UCC_MEMORY_TYPE_HOST) ? tl_team->opt_radix_host : - tl_team->opt_radix; - cfg_radix = ucc_tl_ucp_get_radix_from_range(tl_team, - count * ucc_dt_size(dtype), - mem_type, p, opt_radix); - radix = ucc_knomial_pattern_get_min_radix(cfg_radix, - UCC_TL_TEAM_SIZE(tl_team), - count); + radix = ucc_tl_ucp_get_knomial_radix(tl_team, count, dtype, mem_type, p, 1); /* 1st step of allreduce: knomial reduce_scatter */ UCC_CHECK_GOTO( ucc_tl_ucp_reduce_scatter_knomial_init_r(&args, team, &task, radix), diff --git a/src/components/tl/ucp/tl_ucp_coll.h b/src/components/tl/ucp/tl_ucp_coll.h index 2769244d39..eab127a606 100644 --- a/src/components/tl/ucp/tl_ucp_coll.h +++ b/src/components/tl/ucp/tl_ucp_coll.h @@ -508,4 +508,33 @@ ucc_tl_ucp_get_radix_from_range(ucc_tl_ucp_team_t *team, } return radix; } + +/* + * Get the radix for knomial patterns. + * If need_scratch is true, the radix is the minimum radix that can be used to fit into scratch buffer. + * Otherwise, the radix is the minimum radix that can be used to fit into team size. + */ +static inline unsigned ucc_tl_ucp_get_knomial_radix(ucc_tl_ucp_team_t *team, + size_t count, + ucc_datatype_t dtype, + ucc_memory_type_t mem_type, + ucc_mrange_uint_t *p, + int need_scratch) +{ + size_t msgsize = count * ucc_dt_size(dtype); + unsigned opt_radix, cfg_radix, radix; + + opt_radix = (mem_type == UCC_MEMORY_TYPE_HOST) ? team->opt_radix_host : + team->opt_radix; + cfg_radix = ucc_tl_ucp_get_radix_from_range(team, msgsize, mem_type, p, + opt_radix); + if (need_scratch) { + radix = ucc_knomial_pattern_get_min_radix(cfg_radix, UCC_TL_TEAM_SIZE(team), count); + } else { + radix = ucc_min(cfg_radix, UCC_TL_TEAM_SIZE(team)); + + } + return radix; +} + #endif