Skip to content

Commit

Permalink
TL/SHARP: enable reduce-scatter only with SAT (#1084)
Browse files Browse the repository at this point in the history
  • Loading branch information
bureddy authored Mar 4, 2025
1 parent b7fbf66 commit 5e555d7
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/components/tl/sharp/tl_sharp.c
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ static ucc_config_field_t ucc_tl_sharp_context_config_table[] = {
ucc_offsetof(ucc_tl_sharp_context_config_t, use_rcache),
UCC_CONFIG_TYPE_BOOL},

{"REG_THRESH", "256",
{"REG_THRESH", "0",
"Size threshold to register buffers",
ucc_offsetof(ucc_tl_sharp_context_config_t, reg_threshold),
UCC_CONFIG_TYPE_MEMUNITS},
Expand Down
1 change: 1 addition & 0 deletions src/components/tl/sharp/tl_sharp.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ typedef struct ucc_tl_sharp_context {
ucc_mpool_t req_mp;
ucc_tl_sharp_oob_ctx_t oob_ctx;
ucc_rcache_t *rcache;
struct sharp_coll_caps sharp_caps;
} ucc_tl_sharp_context_t;
UCC_CLASS_DECLARE(ucc_tl_sharp_context_t, const ucc_base_context_params_t *,
const ucc_base_config_t *);
Expand Down
8 changes: 8 additions & 0 deletions src/components/tl/sharp/tl_sharp_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,10 @@ ucc_status_t ucc_tl_sharp_reduce_scatter_init(ucc_tl_sharp_task_t *task)
{
ucc_coll_args_t *args = &TASK_ARGS(task);

if (!(TASK_CTX(task)->sharp_caps.support_mask.feature_mask & SHARP_FEATURE_SAT)) {
return UCC_ERR_NOT_SUPPORTED;
}

if (!ucc_coll_args_is_predefined_dt(args, UCC_RANK_INVALID)) {
return UCC_ERR_NOT_SUPPORTED;
}
Expand Down Expand Up @@ -556,6 +560,10 @@ ucc_status_t ucc_tl_sharp_allgather_init(ucc_tl_sharp_task_t *task)
{
ucc_coll_args_t *args = &TASK_ARGS(task);

if (!(TASK_CTX(task)->sharp_caps.support_mask.feature_mask & SHARP_FEATURE_SAT)) {
return UCC_ERR_NOT_SUPPORTED;
}

if (!ucc_coll_args_is_predefined_dt(args, UCC_RANK_INVALID)) {
return UCC_ERR_NOT_SUPPORTED;
}
Expand Down
8 changes: 8 additions & 0 deletions src/components/tl/sharp/tl_sharp_context.c
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,14 @@ ucc_status_t ucc_tl_sharp_context_init(ucc_tl_sharp_context_t *sharp_ctx,
return UCC_ERR_NO_RESOURCE;
}

ret = sharp_coll_caps_query(*context, &sharp_ctx->sharp_caps);
if (ret < 0) {
tl_error(sharp_ctx->super.super.lib, "sharp_coll_caps_query failed: %s(%d)",
sharp_coll_strerror(ret), ret);
sharp_coll_finalize(*context);
return UCC_ERR_NO_RESOURCE;
}

return UCC_OK;
}

Expand Down
14 changes: 3 additions & 11 deletions src/components/tl/sharp/tl_sharp_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -99,24 +99,16 @@ UCC_CLASS_INIT_FUNC(ucc_tl_sharp_team_t, ucc_base_context_t *tl_context,
SHARP_DTYPE_UNKNOWN) ||
(ucc_to_sharp_dtype[UCC_DT_PREDEFINED_ID(SHARP_DTYPE_BFLOAT16)] ==
SHARP_DTYPE_UNKNOWN)) {
struct sharp_coll_caps sharp_caps;
ret = sharp_coll_caps_query(sharp_ctx, &sharp_caps);
if (ret < 0) {
status = sharp_status_to_ucc_status(ret);
tl_error(ctx->super.super.lib, "sharp_coll_caps_query failed: %s(%d)",
sharp_coll_strerror(ret), ret);
goto cleanup;
}

if (sharp_caps.support_mask.dtypes & UCC_BIT(SHARP_DTYPE_INT8)) {
if (ctx->sharp_caps.support_mask.dtypes & UCC_BIT(SHARP_DTYPE_INT8)) {
tl_debug(ctx->super.super.lib, "enabling support for UCC_DT_INT8");
ucc_to_sharp_dtype[UCC_DT_PREDEFINED_ID(UCC_DT_INT8)] = SHARP_DTYPE_INT8;
} else {
tl_debug(ctx->super.super.lib, "disabling support for UCC_DT_INT8");
ucc_to_sharp_dtype[UCC_DT_PREDEFINED_ID(UCC_DT_INT8)] = SHARP_DTYPE_NULL;
}

if (sharp_caps.support_mask.dtypes & UCC_BIT(SHARP_DTYPE_UINT8)) {
if (ctx->sharp_caps.support_mask.dtypes & UCC_BIT(SHARP_DTYPE_UINT8)) {
tl_debug(ctx->super.super.lib, "enabling support for UCC_DT_UINT8");
ucc_to_sharp_dtype[UCC_DT_PREDEFINED_ID(UCC_DT_UINT8)] = SHARP_DTYPE_UINT8;
} else {
Expand All @@ -125,7 +117,7 @@ UCC_CLASS_INIT_FUNC(ucc_tl_sharp_team_t, ucc_base_context_t *tl_context,
}


if (sharp_caps.support_mask.dtypes & UCC_BIT(SHARP_DTYPE_BFLOAT16)) {
if (ctx->sharp_caps.support_mask.dtypes & UCC_BIT(SHARP_DTYPE_BFLOAT16)) {
tl_debug(ctx->super.super.lib, "enabling support for UCC_DT_BFLOAT16");
ucc_to_sharp_dtype[UCC_DT_PREDEFINED_ID(UCC_DT_BFLOAT16)] = SHARP_DTYPE_BFLOAT16;
} else {
Expand Down

0 comments on commit 5e555d7

Please sign in to comment.