Skip to content

Commit

Permalink
Extract strata aggregation into its own function and use in `compute…
Browse files Browse the repository at this point in the history
…_freq_by_strata`
  • Loading branch information
jkgoodrich committed Jan 9, 2024
1 parent b9f65e1 commit affeb69
Showing 1 changed file with 179 additions and 66 deletions.
245 changes: 179 additions & 66 deletions gnomad/utils/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1861,25 +1861,19 @@ def generate_freq_group_membership_array(
return ht


def compute_freq_by_strata(
def agg_by_strata(
mt: hl.MatrixTable,
entry_agg_funcs: Optional[Dict[str, Tuple[Callable, Callable]]] = None,
select_fields: Optional[List[str]] = None,
group_membership_includes_raw_group: bool = True,
group_membership_ht: Optional[hl.Table] = None,
) -> hl.Table:
"""
Compute call statistics and, when passed, entry aggregation function(s) by strata.
The computed call statistics are AC, AF, AN, and homozygote_count. The entry
aggregation functions are applied to the MatrixTable entries and aggregated. The
MatrixTable must contain a 'group_membership' annotation (like the one added by
`generate_freq_group_membership_array`) that is a list of bools to aggregate the
columns by.
Get row expression for annotations of each entry aggregation function(s) by strata.
.. note::
This function is primarily used through `annotate_freq` but can be used
independently if desired. Please see the `annotate_freq` function for more
complete documentation.
The entry aggregation functions are applied to the MatrixTable entries and
aggregated. If no `group_membership_ht` (like the one returned by
`generate_freq_group_membership_array`) is supplied, `mt` must contain a
'group_membership' annotation that is a list of bools to aggregate the columns by.
:param mt: Input MatrixTable.
:param entry_agg_funcs: Optional dict of entry aggregation functions. When
Expand All @@ -1890,15 +1884,9 @@ def compute_freq_by_strata(
function.
:param select_fields: Optional list of row fields from `mt` to keep on the output
Table.
:param group_membership_includes_raw_group: Whether the 'group_membership'
annotation includes an entry for the 'raw' group, representing all samples. If
False, the 'raw' group is inserted as the second element in all added
annotations using the same 'group_membership', resulting
in array lengths of 'group_membership'+1. If True, the second element of each
added annotation is still the 'raw' group, but the group membership is
determined by the values in the second element of 'group_membership', and the
output annotations will be the same length as 'group_membership'. Default is
True.
:param group_membership_ht: Optional Table containing group membership annotations
to stratify the coverage stats by. If not provided, the 'group_membership'
annotation is expected to be present on `mt`.
:return: Table or MatrixTable with allele frequencies by strata.
"""
if entry_agg_funcs is None:
Expand All @@ -1907,79 +1895,204 @@ def compute_freq_by_strata(
select_fields = []

n_samples = mt.count_cols()
n_groups = len(mt.group_membership.take(1)[0])
ht = mt.localize_entries("entries", "cols")
ht = ht.annotate_globals(
indices_by_group=hl.range(n_groups).map(
lambda g_i: hl.range(n_samples).filter(
lambda s_i: ht.cols[s_i].group_membership[g_i]
global_expr = {}
if "adj_group" in mt.index_globals():
global_expr["adj_group"] = mt.index_globals().adj_group
logger.info("Using the 'adj_group' global annotation found on the input MT.")

if group_membership_ht is None and "group_membership" not in mt.col:
raise ValueError(
"The 'group_membership' annotation is not found in the input MatrixTable "
"and 'group_membership_ht' is not specified."
)
elif group_membership_ht is None:
logger.info(
"'group_membership_ht' is not specified, using sample stratification "
"indicated by the 'group_membership' annotation on mt."
)
n_groups = len(mt.group_membership.take(1)[0])
else:
logger.info(
"'group_membership_ht' is specified, using sample stratification indicated "
"by its 'group_membership' annotation."
)
n_groups = len(group_membership_ht.group_membership.take(1)[0])
mt = mt.annotate_cols(
group_membership=group_membership_ht[mt.col_key].group_membership
)
if "adj_group" not in global_expr:
if "adj_group" in group_membership_ht.index_globals():
global_expr["adj_group"] = mt.index_globals().adj_group
logger.info(
"Using the 'adj_group' global annotation on 'group_membership_ht'."
)
elif "freq_meta" in group_membership_ht.index_globals():
logger.info(
"The 'freq_meta' global annotation is found in "
"'group_membership_ht', using it to determine the adj filtered "
"stratification groups."
)
freq_meta = group_membership_ht.index_globals().freq_meta

global_expr["adj_group"] = freq_meta.map(
lambda x: x.get("group", "NA") == "adj"
)

if "adj_group" not in global_expr:
global_expr["adj_group"] = hl.range(n_groups).map(lambda x: False)

n_adj_group = hl.eval(hl.len(global_expr["adj_group"]))
if hl.eval(hl.len(global_expr["adj_group"])) != n_groups:
raise ValueError(
f"The number of elements in the 'adj_group' ({n_adj_group}) global "
"annotation does not match the number of elements in the "
f"'group_membership' annotation ({n_groups})!",
)

# Keep only the entries needed for the aggregation functions.
select_expr = {}
if hl.eval(hl.any(global_expr["adj_group"])):
select_expr["adj"] = mt.adj

select_expr.update(**{ann: f[0](mt) for ann, f in entry_agg_funcs.items()})
mt = mt.select_entries(**select_expr)

# Convert MT to HT with a row annotation that is an array of all samples entries
# for that variant.
ht = mt.localize_entries("entries", "cols")

# For each stratification group in group_membership, determine the indices of the
# samples that belong to that group.
global_expr["indices_by_group"] = hl.range(n_groups).map(
lambda g_i: hl.range(n_samples).filter(
lambda s_i: ht.cols[s_i].group_membership[g_i]
)
)
ht = ht.annotate_globals(**global_expr)

# Pull out each annotation that will be used in the array aggregation below as its
# own ArrayExpression. This is important to prevent memory issues when performing
# the below array aggregations.
ht = ht.select(
*select_fields,
adj_array=ht.entries.map(lambda e: e.adj),
gt_array=ht.entries.map(lambda e: e.GT),
**{
ann: hl.map(lambda e, s: f[0](e, s), ht.entries, ht.cols)
for ann, f in entry_agg_funcs.items()
},
ann: ht.entries.map(lambda e: e[ann])
for ann in select_fields + list(select_expr.keys())
}
)

def _agg_by_group(
ht: hl.Table, agg_func: Callable, ann_expr: hl.expr.ArrayExpression, *args
ht: hl.Table, agg_func: Callable, ann_expr: hl.expr.ArrayExpression
) -> hl.expr.ArrayExpression:
"""
Aggregate `agg_expr` by group using the `agg_func` function.
:param ht: Input Hail Table.
:param agg_func: Aggregation function to apply to `agg_expr`.
:param agg_expr: Expression to aggregate by group.
:param args: Additional arguments to pass to the `agg_func`.
:param agg_func: Aggregation function to apply to `ann_expr`.
:param ann_expr: Expression to aggregate by group.
:return: Aggregated array expression.
"""
adj_agg_expr = ht.indices_by_group.map(
lambda s_indices: s_indices.aggregate(
lambda i: hl.agg.filter(ht.adj_array[i], agg_func(ann_expr[i], *args))
)
)
# Create final agg list by inserting or changing the "raw" group,
# representing all samples, in the adj_agg_list.
raw_agg_expr = ann_expr.aggregate(lambda x: agg_func(x, *args))
if group_membership_includes_raw_group:
extend_idx = 2
else:
extend_idx = 1

adj_agg_expr = (
adj_agg_expr[:1].append(raw_agg_expr).extend(adj_agg_expr[extend_idx:])
return hl.map(
lambda s_indices, adj: s_indices.aggregate(
lambda i: hl.if_else(
adj,
hl.agg.filter(ht.adj[i], agg_func(ann_expr[i])),
agg_func(ann_expr[i]),
)
),
ht.indices_by_group,
ht.adj_group,
)

return adj_agg_expr

freq_expr = _agg_by_group(ht, hl.agg.call_stats, ht.gt_array, ht.alleles)

# Select non-ref allele (assumes bi-allelic).
freq_expr = freq_expr.map(
lambda cs: cs.annotate(
AC=cs.AC[1],
AF=cs.AF[1],
homozygote_count=cs.homozygote_count[1],
)
)
# Add annotations for any supplied entry transform and aggregation functions.
ht = ht.select(
*select_fields,
**{ann: _agg_by_group(ht, f[1], ht[ann]) for ann, f in entry_agg_funcs.items()},
freq=freq_expr,
)

return ht.drop("cols")


def compute_freq_by_strata(
    mt: hl.MatrixTable,
    entry_agg_funcs: Optional[Dict[str, Tuple[Callable, Callable]]] = None,
    select_fields: Optional[List[str]] = None,
    group_membership_includes_raw_group: bool = True,
) -> hl.Table:
    """
    Compute call statistics and, when passed, entry aggregation function(s) by strata.

    The computed call statistics are AC, AF, AN, and homozygote_count. The entry
    aggregation functions are applied to the MatrixTable entries and aggregated. The
    MatrixTable must contain a 'group_membership' annotation (like the one added by
    `generate_freq_group_membership_array`) that is a list of bools to aggregate the
    columns by.

    The aggregation itself is delegated to `agg_by_strata`; this function adds a
    'freq' entry aggregation and an 'adj_group' global annotation that marks every
    group except the second one (the 'raw' group) as adj filtered.

    .. note::

        This function is primarily used through `annotate_freq` but can be used
        independently if desired. Please see the `annotate_freq` function for more
        complete documentation.

    .. note::

        The 'freq' annotation name is reserved: a 'freq' key supplied in
        `entry_agg_funcs` is replaced by the call statistics aggregation. The
        dictionary passed by the caller is not modified.

    :param mt: Input MatrixTable.
    :param entry_agg_funcs: Optional dict of entry aggregation functions. When
        specified, additional annotations are added to the output Table/MatrixTable.
        The keys of the dict are the names of the annotations and the values are tuples
        of functions. The first function is used to transform the `mt` entries in some
        way, and the second function is used to aggregate the output from the first
        function.
    :param select_fields: Optional list of row fields from `mt` to keep on the output
        Table.
    :param group_membership_includes_raw_group: Whether the 'group_membership'
        annotation includes an entry for the 'raw' group, representing all samples. If
        False, the 'raw' group is inserted as the second element in all added
        annotations using the same 'group_membership', resulting
        in array lengths of 'group_membership'+1. If True, the second element of each
        added annotation is still the 'raw' group, but the group membership is
        determined by the values in the second element of 'group_membership', and the
        output annotations will be the same length as 'group_membership'. Default is
        True.
    :return: Table returned by `agg_by_strata` (with the 'adj_group' global annotation
        dropped) containing allele frequencies by strata.
    """
    if not group_membership_includes_raw_group:
        # Add the 'raw' group to the 'group_membership' annotation by prepending a
        # copy of the first element, so the second position can serve as 'raw'.
        mt = mt.annotate_cols(
            group_membership=hl.array([mt.group_membership[0]]).extend(
                mt.group_membership
            )
        )

    # Add adj_group global annotation indicating that the second element in
    # group_membership is 'raw' and all others are 'adj'.
    mt = mt.annotate_globals(
        adj_group=hl.range(hl.len(mt.group_membership.take(1)[0])).map(
            lambda x: x != 1
        )
    )

    def _get_freq_expr(gt_expr: hl.expr.CallExpression) -> hl.expr.StructExpression:
        """
        Get struct expression with call statistics.

        :param gt_expr: CallExpression to compute call statistics on.
        :return: StructExpression with call statistics.
        """
        # Get the source Table for the CallExpression to grab alleles.
        ht = gt_expr._indices.source
        freq_expr = hl.agg.call_stats(gt_expr, ht.alleles)
        # Select non-ref allele (assumes bi-allelic).
        freq_expr = freq_expr.annotate(
            AC=freq_expr.AC[1],
            AF=freq_expr.AF[1],
            homozygote_count=freq_expr.homozygote_count[1],
        )

        return freq_expr

    # Build a new dict rather than assigning into `entry_agg_funcs`, so the caller's
    # dictionary is left untouched. 'freq' is placed last and overrides any
    # caller-supplied 'freq' key, matching the previous in-place assignment.
    strata_agg_funcs = {
        **(entry_agg_funcs or {}),
        "freq": (lambda x: x.GT, _get_freq_expr),
    }

    return agg_by_strata(mt, strata_agg_funcs, select_fields).drop("adj_group")


def update_structured_annotations(
ht: hl.Table,
annotation_update_exprs: Dict[str, hl.Expression],
Expand Down

0 comments on commit affeb69

Please sign in to comment.