[query] Add query_matrix_table_rows, an analogue of query_table (#14806)
CHANGELOG: Add query_matrix_table_rows, an analogue of query_table

Part of #14499.

## Security Assessment

- This change has no security impact

### Impact Description
Increases the Python query API surface; this is a query-only change.
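
For context, a minimal usage sketch of the new function, assembled from the docstring and tests in this commit (the matrix table path is hypothetical):

```python
import hail as hl

# Hypothetical path to a previously written matrix table.
path = 'data/example.mt'

# Query a single row key; entry fields are localized into the array field 'e'.
hl.eval(hl.query_matrix_table_rows(path, 50, 'e'))

# Query a range of row keys with an interval.
hl.eval(hl.query_matrix_table_rows(path, hl.interval(27, 66), 'e'))
```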
chrisvittal authored Mar 4, 2025
1 parent 51ca2b1 commit fc85a38
Showing 8 changed files with 316 additions and 0 deletions.
10 changes: 10 additions & 0 deletions hail/hail/src/is/hail/expr/ir/IR.scala
@@ -1186,6 +1186,7 @@ package defs {
        classOf[PartitionNativeReader],
        classOf[PartitionNativeReaderIndexed],
        classOf[PartitionNativeIntervalReader],
        classOf[PartitionZippedNativeIntervalReader],
        classOf[PartitionZippedNativeReader],
        classOf[PartitionZippedIndexedNativeReader],
        classOf[BgenPartitionReader],
@@ -1216,6 +1217,15 @@ package defs {
            spec,
            (jv \ "uidFieldName").extract[String],
          )
        case "PartitionZippedNativeIntervalReader" =>
          val path = (jv \ "path").extract[String]
          val spec = RelationalSpec.read(ctx.fs, path).asInstanceOf[AbstractMatrixTableSpec]
          PartitionZippedNativeIntervalReader(
            ctx.stateManager,
            path,
            spec,
            (jv \ "uidFieldName").extract[String],
          )
        case "GVCFPartitionReader" =>
          val header = VCFHeaderInfo.fromJSON((jv \ "header"))
          val callFields = (jv \ "callFields").extract[Set[String]]
56 changes: 56 additions & 0 deletions hail/hail/src/is/hail/expr/ir/TableIR.scala
@@ -943,6 +943,7 @@ case class PartitionNativeIntervalReader(
  lazy val partitioner = rowsSpec.partitioner(sm)

  lazy val contextType: Type = RVDPartitioner.intervalIRRepresentation(partitioner.kType)
  require(partitioner.kType.size > 0)

  def toJValue: JValue = Extraction.decompose(this)(PartitionReader.formats)

@@ -1495,6 +1496,61 @@ case class PartitionZippedNativeReader(left: PartitionReader, right: PartitionReader)
  }
}

case class PartitionZippedNativeIntervalReader(
  sm: HailStateManager,
  mtPath: String,
  mtSpec: AbstractMatrixTableSpec,
  uidFieldName: String,
) extends PartitionReader {
  require(mtSpec.indexed)

  // An entries reader that is repartitioned to match the rows table, so the two
  // streams can be zipped element-wise.
  private[this] class PartitionEntriesNativeIntervalReader(
    sm: HailStateManager,
    entriesPath: String,
    entriesSpec: AbstractTableSpec,
    uidFieldName: String,
    rowsTableSpec: AbstractTableSpec,
  ) extends PartitionNativeIntervalReader(sm, entriesPath, entriesSpec, uidFieldName) {
    override lazy val partitioner = rowsTableSpec.rowsSpec.partitioner(sm)
  }

  // XXX: rows and entries paths are hardcoded, see MatrixTableSpec
  private lazy val rowsReader =
    PartitionNativeIntervalReader(sm, mtPath + "/rows", mtSpec.rowsSpec, "__dummy")

  private lazy val entriesReader =
    new PartitionEntriesNativeIntervalReader(
      sm,
      mtPath + "/entries",
      mtSpec.entriesSpec,
      uidFieldName,
      rowsReader.tableSpec,
    )

  private lazy val zippedReader = PartitionZippedNativeReader(rowsReader, entriesReader)

  def contextType = rowsReader.contextType
  def fullRowType = zippedReader.fullRowType
  def rowRequiredness(requestedType: TStruct): RStruct = zippedReader.rowRequiredness(requestedType)
  def toJValue: JValue = Extraction.decompose(this)(PartitionReader.formats)

  def emitStream(
    ctx: ExecuteContext,
    cb: EmitCodeBuilder,
    mb: EmitMethodBuilder[_],
    codeContext: EmitCode,
    requestedType: TStruct,
  ): IEmitCode = {
    // Both the rows and the entries reader consume the same interval context,
    // so duplicate it into a two-field struct for the zipped reader.
    val zipContextType: TBaseStruct = tcoerce(zippedReader.contextType)
    val valueContext = cb.memoize(codeContext)
    val contexts: IndexedSeq[EmitCode] = FastSeq(valueContext, valueContext)
    val st = SStackStruct(zipContextType, contexts.map(_.emitType))
    val context = EmitCode.present(mb, st.fromEmitCodes(cb, contexts))

    zippedReader.emitStream(ctx, cb, mb, context, requestedType)
  }
}

case class PartitionZippedIndexedNativeReader(
specLeft: AbstractTypedCodecSpec,
specRight: AbstractTypedCodecSpec,
2 changes: 2 additions & 0 deletions hail/python/hail/expr/__init__.py
@@ -210,6 +210,7 @@
    qchisqtail,
    qnorm,
    qpois,
    query_matrix_table_rows,
    query_table,
    rand_beta,
    rand_bool,
@@ -554,6 +555,7 @@
    '_console_log',
    'dnorm',
    'dchisq',
    'query_matrix_table_rows',
    'query_table',
    'keyed_union',
    'keyed_intersection',
64 changes: 64 additions & 0 deletions hail/python/hail/expr/functions.py
@@ -2,6 +2,7 @@
import functools
import itertools
import operator
import os.path
from typing import Any, Callable, Iterable, Optional, TypeVar, Union

import numpy as np
@@ -7071,6 +7072,69 @@ def query_table(path, point_or_interval):
    )


@typecheck(path=builtins.str, point_or_interval=expr_any, entries_name=builtins.str)
def query_matrix_table_rows(path, point_or_interval, entries_name='entries_array'):
    """Query row records from a matrix table corresponding to a given point or
    range of row keys. The entry fields are localized as an array of structs, as
    in :meth:`.MatrixTable.localize_entries`.

    Notes
    -----
    This function does not dispatch to a distributed runtime; it can be used
    inside already-distributed queries, such as in :meth:`.Table.annotate`.

    Warning
    -------
    This function contains no safeguards against reading large amounts of data
    using a single thread.

    Parameters
    ----------
    path : :class:`str`
        Matrix table path.
    point_or_interval
        Point or interval to query.
    entries_name : :class:`str`
        Identifier to use for the localized entries array. Must not conflict
        with any row field identifiers. Defaults to ``entries_array``.

    Returns
    -------
    :class:`.ArrayExpression`
    """
    matrix_table = hl.read_matrix_table(path)
    if entries_name in matrix_table.row:
        raise ValueError(
            f'field "{entries_name}" is present in matrix table row fields, use a different `entries_name`'
        )
    # The entries table is stored at a fixed path inside the matrix table directory.
    entries_table = hl.read_table(os.path.join(path, 'entries'))
    [entry_id] = list(entries_table.row)

    full_row_type = tstruct(**matrix_table.row.dtype, **entries_table.row.dtype)
    key_typ = matrix_table.row_key.dtype

    if point_or_interval.dtype != key_typ[0] and isinstance(point_or_interval.dtype, hl.tinterval):
        partition_interval = hl.interval(
            start=__validate_and_coerce_endpoint(point_or_interval.start, key_typ),
            end=__validate_and_coerce_endpoint(point_or_interval.end, key_typ),
            includes_start=point_or_interval.includes_start,
            includes_end=point_or_interval.includes_end,
        )
    else:
        point = __validate_and_coerce_endpoint(point_or_interval, key_typ)
        partition_interval = hl.interval(start=point, end=point, includes_start=True, includes_end=True)
    read_part_ir = ir.ReadPartition(
        partition_interval._ir, reader=ir.PartitionZippedNativeIntervalReader(path, full_row_type)
    )
    stream_expr = construct_expr(
        read_part_ir,
        type=hl.tstream(full_row_type),
        indices=partition_interval._indices,
        aggregations=partition_interval._aggregations,
    )
    # Rename the single entries field to `entries_name` and collect the stream.
    return stream_expr.map(lambda item: item.rename({entry_id: entries_name})).to_array()


@typecheck(msg=expr_str, result=expr_any)
def _console_log(msg, result):
    indices, aggregations = unify_all(msg, result)
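To make the return shape concrete, a sketch grounded in the `query_mt_mt` fixture from the tests in this commit (the path is hypothetical): each matching row key yields one row struct whose entries field holds one entry struct per column.

```python
import hail as hl

path = 'data/query_example.mt'  # hypothetical temp path
mt = hl.utils.range_matrix_table(n_rows=200, n_cols=100, n_partitions=10)
mt = mt.filter_rows(mt.row_idx % 10 == 0)
mt = mt.filter_cols(mt.col_idx % 10 == 0)
mt = mt.annotate_rows(s=hl.str(mt.row_idx))
mt = mt.annotate_entries(n=mt.row_idx * mt.col_idx)
mt.write(path)

# One row struct per matching row key; 'e' holds one entry struct per column.
[row] = hl.eval(hl.query_matrix_table_rows(path, 50, 'e'))
assert row.row_idx == 50 and row.s == '50'
assert row.e == [hl.Struct(n=50 * m) for m in range(0, 100, 10)]
```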
2 changes: 2 additions & 0 deletions hail/python/hail/ir/__init__.py
@@ -113,6 +113,7 @@
    NDArraySVD,
    NDArrayWrite,
    PartitionNativeIntervalReader,
    PartitionZippedNativeIntervalReader,
    ProjectedTopLevelReference,
    ReadPartition,
    Recur,
@@ -527,6 +528,7 @@
    'TableNativeFanoutWriter',
    'ReadPartition',
    'PartitionNativeIntervalReader',
    'PartitionZippedNativeIntervalReader',
    'GVCFPartitionReader',
    'TableGen',
    'Partitioner',
31 changes: 31 additions & 0 deletions hail/python/hail/ir/ir.py
@@ -3510,6 +3510,37 @@ def row_type(self):
        return tstruct(**self.table_row_type, **{self.uid_field: ttuple(tint64, tint64)})


class PartitionZippedNativeIntervalReader(PartitionReader):
    def __init__(self, path, full_row_type, uid_field=None):
        self.path = path
        self.full_row_type = full_row_type
        self.uid_field = uid_field

    def with_uid_field(self, uid_field):
        # Preserve full_row_type when attaching a uid field.
        return PartitionZippedNativeIntervalReader(
            path=self.path, full_row_type=self.full_row_type, uid_field=uid_field
        )

    def render(self):
        return escape_str(
            json.dumps({
                "name": "PartitionZippedNativeIntervalReader",
                "path": self.path,
                "uidFieldName": self.uid_field if self.uid_field is not None else '__dummy',
            })
        )

    def _eq(self, other):
        return (
            isinstance(other, PartitionZippedNativeIntervalReader)
            and self.path == other.path
            and self.uid_field == other.uid_field
        )

    def row_type(self):
        if self.uid_field is None:
            return self.full_row_type
        return tstruct(**self.full_row_type, **{self.uid_field: ttuple(tint64, tint64)})


class ReadPartition(IR):
    @typecheck_method(context=IR, reader=PartitionReader)
    def __init__(self, context, reader):
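For reference, a sketch of the JSON this reader renders, which the "PartitionZippedNativeIntervalReader" case added to IR.scala above deserializes (the path and row type are hypothetical, and the output is shown approximately, before escaping):

```python
# Assumes PartitionZippedNativeIntervalReader and a row type `full_row_type`
# are in scope; the path is hypothetical.
reader = PartitionZippedNativeIntervalReader('data/example.mt', full_row_type)
print(reader.render())
# Approximately: {"name": "PartitionZippedNativeIntervalReader",
#                 "path": "data/example.mt", "uidFieldName": "__dummy"}
```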
143 changes: 143 additions & 0 deletions hail/python/test/hail/matrixtable/test_matrix_table.py
@@ -2383,3 +2383,146 @@ def test_struct_of_arrays_encoding():
    etype = md['_codecSpec']['_eType']
    assert 'EStructOfArrays' in etype
    assert mt._same(std_mt)


@pytest.fixture(scope='module')
def query_mt_mt():
    path = new_temp_file(extension='mt')
    mt = hl.utils.range_matrix_table(n_rows=200, n_cols=100, n_partitions=10)
    mt = mt.filter_rows(mt.row_idx % 10 == 0)
    mt = mt.filter_cols(mt.col_idx % 10 == 0)
    mt = mt.annotate_rows(s=hl.str(mt.row_idx))
    mt = mt.annotate_entries(n=mt.row_idx * mt.col_idx)
    mt.write(path)
    return path


def test_query_matrix_table_rows_errors(query_mt_mt):
    with pytest.raises(ValueError, match='field "s" is present'):
        hl.query_matrix_table_rows(query_mt_mt, 0, 's')
    with pytest.raises(ValueError, match='key mismatch: cannot use'):
        hl.query_matrix_table_rows(query_mt_mt, hl.interval('1', '2'))
    with pytest.raises(ValueError, match='key mismatch: cannot use'):
        hl.query_matrix_table_rows(query_mt_mt, '1')
    with pytest.raises(ValueError, match='query point value cannot be an empty struct'):
        hl.query_matrix_table_rows(query_mt_mt, hl.struct())
    with pytest.raises(ValueError, match='query point type has 2 field'):
        hl.query_matrix_table_rows(query_mt_mt, hl.struct(idx=5, foo='s'))


def query_matrix_table_rows_test_parameters():
    def ea_for(n):
        return [hl.Struct(n=n * m) for m in range(0, 100, 10)]

    return [
        (50, [hl.Struct(row_idx=50, s='50', e=ea_for(50))]),
        (hl.struct(idx=50), [hl.Struct(row_idx=50, s='50', e=ea_for(50))]),
        (55, []),
        (5, []),
        (-1, []),
        (205, []),
        (
            hl.interval(27, 66),
            [
                hl.Struct(row_idx=30, s='30', e=ea_for(30)),
                hl.Struct(row_idx=40, s='40', e=ea_for(40)),
                hl.Struct(row_idx=50, s='50', e=ea_for(50)),
                hl.Struct(row_idx=60, s='60', e=ea_for(60)),
            ],
        ),
        (hl.interval(276, 33333), []),
        (hl.interval(-22276, -5), []),
        (
            hl.interval(hl.struct(row_idx=27), hl.struct(row_idx=66)),
            [
                hl.Struct(row_idx=30, s='30', e=ea_for(30)),
                hl.Struct(row_idx=40, s='40', e=ea_for(40)),
                hl.Struct(row_idx=50, s='50', e=ea_for(50)),
                hl.Struct(row_idx=60, s='60', e=ea_for(60)),
            ],
        ),
        (
            hl.interval(40, 80, includes_end=True),
            [
                hl.Struct(row_idx=40, s='40', e=ea_for(40)),
                hl.Struct(row_idx=50, s='50', e=ea_for(50)),
                hl.Struct(row_idx=60, s='60', e=ea_for(60)),
                hl.Struct(row_idx=70, s='70', e=ea_for(70)),
                hl.Struct(row_idx=80, s='80', e=ea_for(80)),
            ],
        ),
    ]


@pytest.mark.parametrize("query,expected", query_matrix_table_rows_test_parameters())
def test_query_matrix_table_rows(query_mt_mt, query, expected):
    assert hl.eval(hl.query_matrix_table_rows(query_mt_mt, query, 'e')) == expected


def test_query_matrix_table_rows_randomness(query_mt_mt):
    i1 = hl.interval(27, 45)
    i2 = hl.interval(45, 80, includes_end=True)
    rows = hl.query_matrix_table_rows(query_mt_mt, i1, 'e').extend(hl.query_matrix_table_rows(query_mt_mt, i2, 'e'))
    x = hl.eval(rows.aggregate(lambda _: hl.struct(r=hl.agg.collect_as_set(hl.rand_int64()), n=hl.agg.count())))
    assert len(x.r) == x.n


@pytest.fixture(scope='module')
def query_mt_compound_key_mt():
    path = new_temp_file(extension='mt')
    mt = hl.utils.range_matrix_table(n_rows=200, n_cols=100, n_partitions=10)
    mt = mt.filter_rows(mt.row_idx % 10 == 0)
    mt = mt.filter_cols(mt.col_idx % 10 == 0)
    mt = mt.annotate_rows(idx2=mt.row_idx % 20, s=hl.str(mt.row_idx))
    mt = mt.annotate_entries(n=mt.row_idx * mt.col_idx)
    mt = mt.key_rows_by('row_idx', 'idx2')
    mt.write(path)
    return path


def query_matrix_table_rows_compound_key_parameters():
    def ea_for(n):
        return [hl.Struct(n=n * m) for m in range(0, 100, 10)]

    return [
        (50, [hl.Struct(row_idx=50, idx2=10, s='50', e=ea_for(50))]),
        (hl.struct(row_idx=50), [hl.Struct(row_idx=50, idx2=10, s='50', e=ea_for(50))]),
        (hl.interval(hl.struct(row_idx=50, idx2=11), hl.struct(row_idx=60, idx2=-1)), []),
    ]


@pytest.mark.parametrize("query,expected", query_matrix_table_rows_compound_key_parameters())
def test_query_matrix_table_rows_compound_key(query_mt_compound_key_mt, query, expected):
    assert hl.eval(hl.query_matrix_table_rows(query_mt_compound_key_mt, query, 'e')) == expected


@pytest.fixture(scope='module')
def query_mt_interval_key_mt():
    path = new_temp_file(extension='mt')
    mt = hl.utils.range_matrix_table(n_rows=200, n_cols=100, n_partitions=10)
    mt = mt.filter_rows(mt.row_idx % 10 == 0)
    mt = mt.filter_cols(mt.col_idx % 10 == 0)
    mt = mt.annotate_entries(n=mt.row_idx * mt.col_idx)
    mt = mt.key_rows_by(interval=hl.interval(mt.row_idx, mt.row_idx + 50))
    mt.write(path)
    return path


def query_matrix_table_rows_interval_key_parameters():
    def ea_for(n):
        return [hl.Struct(n=n * m) for m in range(0, 100, 10)]

    return [
        (hl.interval(20, 70), [hl.Struct(row_idx=20, interval=hl.Interval(20, 70), e=ea_for(20))]),
        (hl.interval(20, 0), []),
        (hl.struct(interval=hl.interval(20, 0)), []),
        (
            hl.interval(hl.interval(15, 10), hl.interval(20, 71)),
            [hl.Struct(row_idx=20, interval=hl.Interval(20, 70), e=ea_for(20))],
        ),
    ]


@pytest.mark.parametrize("query,expected", query_matrix_table_rows_interval_key_parameters())
def test_query_matrix_table_rows_interval_key(query_mt_interval_key_mt, query, expected):
    assert hl.eval(hl.query_matrix_table_rows(query_mt_interval_key_mt, query, 'e')) == expected