Skip to content

Commit

Permalink
ENH: mods to EnsemblGffRecord
Browse files Browse the repository at this point in the history
[NEW] .spans_array() returns the spans attribute as a 32-bit
    numpy array

[CHANGED] .update_record() (from update_from_attrs()) now updates
    start / stop values, also added a flag attribute to only do this
    once

[NEW] added _array_int32 function for converting span data to numpy array

[NEW] .to_record() method is responsible for preparing a record
    for addition into a database. Has optional control over whether
    to convert the spans numpy array to a BLOB for db storage. Note
    that this is not used for sqlite3 since cogent3 defines a custom array
    field that handles that.
  • Loading branch information
GavinHuttley committed Jul 29, 2024
1 parent 3994f70 commit d5b9df4
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/ensembl_lite/_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def __init__(
try:
self._file = h5py.File(source, mode=self.mode, **h5_kwargs)
except OSError:
print(source)
print(f"{source=}")
raise

if "r" not in self.mode and "align_name" not in self._file.attrs:
Expand Down
126 changes: 106 additions & 20 deletions src/ensembl_lite/_genome.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,32 @@ def tidy_gff3_stableids(attrs: str) -> str:
return _typed_id.sub(_lower_case_match, attrs)


@functools.singledispatch
def _array_int32(data: Any) -> numpy.ndarray:
"""coerce data to a 32-bit numpy array
Notes
-----
intended for use by the EnsemblGffRecord class in producing
the spans_array
"""
return numpy.array(data, dtype=numpy.int32)


@_array_int32.register
def _(data: numpy.ndarray) -> numpy.ndarray:
return data.astype(numpy.int32)


@_array_int32.register
def _(data: bytes) -> numpy.ndarray:
return elt_mixin.blob_to_array(data)


class EnsemblGffRecord(GffRecord):
__slots__ = GffRecord.__slots__ + ("feature_id",)
"""this is a mutable object!"""

__slots__ = GffRecord.__slots__ + ("feature_id", "_is_updated")

def __init__(self, feature_id: Optional[int] = None, **kwargs):
is_canonical = kwargs.pop("is_canonical", None)
Expand All @@ -74,6 +98,8 @@ def __init__(self, feature_id: Optional[int] = None, **kwargs):
if descr:
self.attrs = f"description={descr};" + (self.attrs or "")

self._is_updated: bool = False

def __hash__(self) -> int:
    """hash is computed from the record name only"""
    return hash(self.name)

Expand Down Expand Up @@ -101,14 +127,17 @@ def is_canonical(self) -> bool:
attrs = self.attrs or ""
return "Ensembl_canonical" in attrs

def update_from_attrs(self) -> None:
"""updates attributes from the attrs string
def update_record(self) -> None:
"""updates attributes
Notes
-----
also updates biotype from the prefix in the name
uses the attrs string, also updates biotype from
the prefix in the name and start / stop from the spans
"""
attrs = self.attrs
if self._is_updated:
return
attrs = self.attrs or ""
id_regex = _feature_id if "ID=" in attrs else _exon_id
attr = tidy_gff3_stableids(attrs)
if feature_id := id_regex.search(attr):
Expand All @@ -124,11 +153,60 @@ def update_from_attrs(self) -> None:
biotype = self.name.split(":")[0]
self.biotype = "mrna" if biotype == "transcript" else biotype

if self.spans is not None:
spans = self.spans_array()
self.start = int(spans.min())
self.stop = int(spans.max())

self.start = None if self.start is None else int(self.start)
self.stop = None if self.stop is None else int(self.stop)

self._is_updated = True

@property
def size(self) -> int:
    """total number of positions covered by the spans (0 when unset)"""
    if self.spans is None:
        return 0
    return sum(abs(start - end) for start, end in self.spans)

def spans_array(self) -> numpy.ndarray | None:
    """the spans attribute coerced to a 32-bit numpy array, or None if unset"""
    if self.spans is None:
        return None
    return _array_int32(self.spans)

def to_record(
    self,
    *,
    fields: list[str] | None = None,
    exclude_null: bool = False,
    array_to_blob: bool = False,
) -> dict[str, str | bytes | int | None]:
    """converts the record into a plain dict suitable for database storage

    Parameters
    ----------
    fields
        attribute names to include; defaults to every slot except the
        private _is_updated flag
    exclude_null
        omit entries whose value is None
    array_to_blob
        serialise the spans array to bytes; leave False for sqlite3
        based cogent3 databases since those declare a custom type
        for arrays

    Notes
    -----
    spans are converted to a 32-bit numpy array
    """
    selected = fields or [attr for attr in self.__slots__ if attr != "_is_updated"]
    result = {}
    for name in selected:
        if name == "spans":
            value = self.spans_array()
            if array_to_blob and value is not None:
                value = elt_mixin.array_to_blob(value)
        else:
            value = self[name]
        # optionally drop null values from the record
        if value is None and exclude_null:
            continue
        result[name] = value
    return result


def custom_gff_parser(
path: elt_util.PathType, num_fake_ids: int
Expand All @@ -141,7 +219,7 @@ def custom_gff_parser(
gff3=gff3,
make_record=EnsemblGffRecord,
):
record.update_from_attrs()
record.update_record()
if not record.name:
record.name = f"unknown-{num_fake_ids}"
num_fake_ids += 1
Expand Down Expand Up @@ -182,7 +260,7 @@ class EnsemblGffDb(elt_mixin.SqliteDbMixin):
"description": "TEXT",
"attributes": "TEXT",
"comments": "TEXT",
"spans": "array", # aggregation of coords across records
"spans": "BLOB", # aggregation of coords across records
"stableid": "TEXT",
"id": "INTEGER PRIMARY KEY AUTOINCREMENT",
"is_canonical": "INTEGER",
Expand Down Expand Up @@ -301,14 +379,11 @@ def add_feature(
if feature is None:
feature = self._build_feature(kwargs)

feature.update_record() # custom_gff_parser already does this
id_cols = ("biotype_id", "id")
cols = [col for col in self._feature_schema if col not in id_cols]
# do conversion to numpy array after the above statement to avoid issue of
# having a numpy array in a conditional
feature.spans = numpy.array(feature.spans)
feature.start = feature.start or int(feature.spans.min())
feature.stop = feature.stop or int(feature.spans.max())
vals = [feature[col] for col in cols] + [self._get_biotype_id(feature.biotype)]
record = feature.to_record(fields=cols, array_to_blob=True)
vals = [record[col] for col in cols] + [self._get_biotype_id(feature.biotype)]
cols += ["biotype_id"]
placeholders = ",".join("?" * len(cols))
sql = f"INSERT INTO feature({','.join(cols)}) VALUES ({placeholders}) RETURNING id"
Expand Down Expand Up @@ -383,7 +458,9 @@ def get_features_matching(
table_name="gff", columns=columns, **query_args
):
result = dict(zip(columns, result))
result["spans"] = [tuple(c) for c in result["spans"]]
result["spans"] = [
tuple(c) for c in elt_mixin.blob_to_array(result["spans"])
]
yield result

def get_feature_children(
Expand All @@ -398,7 +475,9 @@ def get_feature_children(
table_name="parent_to_child", columns=cols, parent_stableid=name, **kwargs
):
result = dict(zip(cols, result))
result["spans"] = [tuple(c) for c in result["spans"]]
result["spans"] = [
tuple(c) for c in elt_mixin.blob_to_array(result["spans"])
]
results[result["name"]] = result
return list(results.values())

Expand All @@ -414,6 +493,9 @@ def get_feature_parent(
table_name="child_to_parent", columns=cols, child_stableid=name
):
result = dict(zip(cols, result))
result["spans"] = [
tuple(c) for c in elt_mixin.blob_to_array(result["spans"])
]
results[result["name"]] = result
return list(results.values())

Expand All @@ -433,11 +515,15 @@ def get_records_matching(
k: v for k, v in locals().items() if k not in ("self", "allow_partial")
}
sql, vals = _select_records_sql("gff", kwargs, allow_partial=allow_partial)
col_names = None
cols = None
for result in self._execute_sql(sql, values=vals):
if col_names is None:
col_names = result.keys()
yield {c: result[c] for c in col_names}
if cols is None:
cols = result.keys()
result = dict(zip(cols, result))
result["spans"] = [
tuple(c) for c in elt_mixin.blob_to_array(result["spans"])
]
yield result

def biotype_counts(self) -> dict[str, int]:
sql = "SELECT biotype, COUNT(*) as count FROM gff GROUP BY biotype"
Expand Down Expand Up @@ -529,7 +615,7 @@ def get_stableid_prefixes(records: typing.Sequence[EnsemblGffRecord]) -> set[str
"""returns the prefixes of the stableids"""
prefixes = set()
for record in records:
record.update_from_attrs()
record.update_record()
try:
prefix = elt_util.get_stableid_prefix(record.stableid)
except ValueError:
Expand Down
32 changes: 32 additions & 0 deletions src/ensembl_lite/_storage_mixin.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,45 @@
import contextlib
import dataclasses
import functools
import io
import os
import sqlite3

import numpy

from ensembl_lite import _util as elt_util

ReturnType = tuple[str, tuple] # the sql statement and corresponding values


@functools.singledispatch
def array_to_blob(data: numpy.ndarray) -> bytes:
    """serialise a numpy array to bytes using the .npy format"""
    buffer = io.BytesIO()
    numpy.save(buffer, data)
    return buffer.getvalue()


@array_to_blob.register
def _(data: bytes) -> bytes:
    # bytes input is taken to already be a serialised blob
    return data


@functools.singledispatch
def blob_to_array(data: bytes) -> numpy.ndarray:
    """deserialise bytes produced by array_to_blob back into a numpy array"""
    buffer = io.BytesIO(data)
    return numpy.load(buffer)


@blob_to_array.register
def _(data: numpy.ndarray) -> numpy.ndarray:
    # already an array — nothing to inflate
    return data


def _make_table_sql(
table_name: str,
columns: dict,
Expand Down
3 changes: 1 addition & 2 deletions tests/test_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,6 @@ def test_select_alignment_minus_strand(start_end, namer):
spans=[(max(1, start or 0), min(end or 12, 12))],
)
expect = aln[ft.map.start : min(ft.map.end, 12)]

# mouse sequence is on minus strand, so need to adjust
# coordinates for query
s2 = aln.get_seq("s2")
Expand Down Expand Up @@ -305,7 +304,7 @@ def test_select_alignment_minus_strand(start_end, namer):
)
)
# drop the strand info
assert len(got) == 1
assert len(got) == 1, f"{s2_ft=}"
assert got[0].to_dict() == expect.to_dict()


Expand Down
25 changes: 25 additions & 0 deletions tests/test_dbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ensembl_lite import _align as elt_align
from ensembl_lite import _homology as elt_homology
from ensembl_lite import _maf as elt_maf
from ensembl_lite import _storage_mixin as elt_mixin


@pytest.fixture(scope="function")
Expand Down Expand Up @@ -106,3 +107,27 @@ def test_pickling_db(db_align):
pkl = pickle.dumps(db_align) # nosec B301
upkl = pickle.loads(pkl) # nosec B301
assert db_align.source == upkl.source


@pytest.mark.parametrize(
    "data", (numpy.array([], dtype=numpy.int32), numpy.array([0, 3], dtype=numpy.uint8))
)
def test_array_blob_roundtrip(data):
    """an array survives serialisation to a blob and back, dtype included"""
    blob = elt_mixin.array_to_blob(data)
    assert isinstance(blob, bytes)
    inflated = elt_mixin.blob_to_array(blob)
    assert numpy.array_equal(inflated, data)
    # compare dtypes with ==, not `is`: identity only holds because numpy
    # interns built-in dtype objects, which is an implementation detail
    assert inflated.dtype == data.dtype


@pytest.mark.parametrize(
    "data",
    (
        numpy.array([0, 3], dtype=numpy.uint8),
        elt_mixin.array_to_blob(numpy.array([0, 3], dtype=numpy.uint8)),
    ),
)
def test_blob_array(data):
    """blob_to_array accepts either an ndarray or its serialised bytes"""
    expect = numpy.array([0, 3], dtype=numpy.uint8)
    got = elt_mixin.blob_to_array(data)
    assert numpy.array_equal(got, expect)
40 changes: 40 additions & 0 deletions tests/test_genome.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from cogent3 import make_unaligned_seqs
from ensembl_lite import _genome as elt_genome
from ensembl_lite import _storage_mixin as elt_mixin
from numpy.testing import assert_allclose


Expand Down Expand Up @@ -379,6 +380,45 @@ def test_gff_record_hashing():
assert v == n


@pytest.mark.parametrize("exclude_null", (True, False))
def test_gff_record_to_record(exclude_null):
    """to_record returns all populated fields plus (optionally) null slots"""
    data = dict(
        seqid="s1",
        name="gene-01",
        biotype="gene",
        spans=[(1, 3), (7, 9)],
        start=1,
        stop=9,
        strand="+",
    )
    record = elt_genome.EnsemblGffRecord(**data)
    got = record.to_record(exclude_null=exclude_null)
    # spans come back as a numpy array, so compare them separately
    got_spans = got.pop("spans")
    if exclude_null:
        expect = dict(data)
    else:
        expect = {slot: None for slot in elt_genome.EnsemblGffRecord.__slots__}
        expect.pop("_is_updated", None)
        expect.update(data)
    expect.pop("spans")
    assert got == expect
    assert numpy.array_equal(elt_mixin.blob_to_array(got_spans), data["spans"])


@pytest.mark.parametrize("exclude_null", (True, False))
def test_gff_record_to_record_selected_fields(exclude_null):
    """restricting fields limits the record; exclude_null drops None values"""
    data = {
        "seqid": "s1",
        "name": "gene-01",
        "start": None,
        "stop": None,
    }
    record = elt_genome.EnsemblGffRecord(**data)
    got = record.to_record(fields=list(data), exclude_null=exclude_null)
    if exclude_null:
        expect = {k: v for k, v in data.items() if v is not None}
    else:
        expect = dict(data)
    assert got == expect


@pytest.fixture
def ensembl_gff_records(DATA_DIR):
records, _ = elt_genome.custom_gff_parser(
Expand Down

0 comments on commit d5b9df4

Please sign in to comment.