From d5b9df46c12d824d575b4e38a1065c854e428145 Mon Sep 17 00:00:00 2001 From: Gavin Huttley Date: Mon, 29 Jul 2024 14:30:38 +1000 Subject: [PATCH] ENH: mods to EnsemblGffRecord [NEW] .spans_array() returns the spans attribute as a 32-bit numpy array [CHANGED] .update_record() (from update_from_attrs()) now updates start / stop values, also added a flag attribute to only do this once [NEW] added _array_int32 function for converting span data to numpy array [NEW] .to_record() method is responsible for preparing a record for addition into a database. Has optional control over whether to convert the spans numpy array to a BLOB for db storage. Note that this not used for sqlite3 since cogent3 defines a custom array field that handles that. --- src/ensembl_lite/_align.py | 2 +- src/ensembl_lite/_genome.py | 126 ++++++++++++++++++++++++----- src/ensembl_lite/_storage_mixin.py | 32 ++++++++ tests/test_align.py | 3 +- tests/test_dbs.py | 25 ++++++ tests/test_genome.py | 40 +++++++++ 6 files changed, 205 insertions(+), 23 deletions(-) diff --git a/src/ensembl_lite/_align.py b/src/ensembl_lite/_align.py index 6bf909b..977b5a2 100644 --- a/src/ensembl_lite/_align.py +++ b/src/ensembl_lite/_align.py @@ -98,7 +98,7 @@ def __init__( try: self._file = h5py.File(source, mode=self.mode, **h5_kwargs) except OSError: - print(source) + print(f"{source=}") raise if "r" not in self.mode and "align_name" not in self._file.attrs: diff --git a/src/ensembl_lite/_genome.py b/src/ensembl_lite/_genome.py index 04815d7..491cade 100644 --- a/src/ensembl_lite/_genome.py +++ b/src/ensembl_lite/_genome.py @@ -56,8 +56,32 @@ def tidy_gff3_stableids(attrs: str) -> str: return _typed_id.sub(_lower_case_match, attrs) +@functools.singledispatch +def _array_int32(data: Any) -> numpy.ndarray: + """coerce data to a 32-bit numpy array + + Notes + ----- + intended for use by the EnsemblGffRecord class in producing + the spans_array + """ + return numpy.array(data, dtype=numpy.int32) + + +@_array_int32.register +def _(data: numpy.ndarray) -> numpy.ndarray: + return data.astype(numpy.int32) + + +@_array_int32.register +def _(data: bytes) -> numpy.ndarray: + return elt_mixin.blob_to_array(data) + + class EnsemblGffRecord(GffRecord): - __slots__ = GffRecord.__slots__ + ("feature_id",) + """this is a mutable object!""" + + __slots__ = GffRecord.__slots__ + ("feature_id", "_is_updated") def __init__(self, feature_id: Optional[int] = None, **kwargs): is_canonical = kwargs.pop("is_canonical", None) @@ -74,6 +98,8 @@ def __init__(self, feature_id: Optional[int] = None, **kwargs): if descr: self.attrs = f"description={descr};" + (self.attrs or "") + self._is_updated: bool = False + def __hash__(self) -> int: return hash(self.name) @@ -101,14 +127,17 @@ def is_canonical(self) -> bool: attrs = self.attrs or "" return "Ensembl_canonical" in attrs - def update_from_attrs(self) -> None: - """updates attributes from the attrs string + def update_record(self) -> None: + """updates attributes Notes ----- - also updates biotype from the prefix in the name + uses the attrs string, also updates biotype from + the prefix in the name and start / stop from the spans """ - attrs = self.attrs + if self._is_updated: + return + attrs = self.attrs or "" id_regex = _feature_id if "ID=" in attrs else _exon_id attr = tidy_gff3_stableids(attrs) if feature_id := id_regex.search(attr): @@ -124,11 +153,60 @@ def update_from_attrs(self) -> None: biotype = self.name.split(":")[0] self.biotype = "mrna" if biotype == "transcript" else biotype + if self.spans is not None: + 
spans = self.spans_array() + self.start = int(spans.min()) + self.stop = int(spans.max()) + + self.start = None if self.start is None else int(self.start) + self.stop = None if self.stop is None else int(self.stop) + + self._is_updated = True + @property def size(self) -> int: """the sum of span segments""" return 0 if self.spans is None else sum(abs(s - e) for s, e in self.spans) + def spans_array(self) -> numpy.ndarray | None: + """returns the spans as a 32-bit array""" + return None if self.spans is None else _array_int32(self.spans) + + def to_record( + self, + *, + fields: list[str] | None = None, + exclude_null: bool = False, + array_to_blob: bool = False, + ) -> dict[str, str | bytes | int | None]: + """returns dict with values suitable for a database record + + Parameters + ---------- + fields + names of attributes to include, by default all + exclude_null + excludes fields with None values + array_to_blob + converts numpy arrays to bytes, don't use for sqlite3 based + cogent3 databases since those declare a custom type for arrays + + Notes + ----- + spans are converted to a 32-bit numpy array + """ + record = {} + fields = fields or [s for s in self.__slots__ if s != "_is_updated"] + for field in fields: + value = self.spans_array() if field == "spans" else self[field] + if field == "spans" and array_to_blob and value is not None: + value = elt_mixin.array_to_blob(value) + + if exclude_null and value is None: + continue + record[field] = value + return record + def custom_gff_parser( path: elt_util.PathType, num_fake_ids: int @@ -141,7 +219,7 @@ def custom_gff_parser( gff3=gff3, make_record=EnsemblGffRecord, ): - record.update_from_attrs() + record.update_record() if not record.name: record.name = f"unknown-{num_fake_ids}" num_fake_ids += 1 @@ -182,7 +260,7 @@ class EnsemblGffDb(elt_mixin.SqliteDbMixin): "description": "TEXT", "attributes": "TEXT", "comments": "TEXT", - "spans": "array", # aggregation of coords across records + "spans": "BLOB", # aggregation of coords across records "stableid": "TEXT", "id": "INTEGER PRIMARY KEY AUTOINCREMENT", "is_canonical": "INTEGER", @@ -301,14 +379,11 @@ def add_feature( if feature is None: feature = self._build_feature(kwargs) + feature.update_record() # custom_gff_parser already does this id_cols = ("biotype_id", "id") cols = [col for col in self._feature_schema if col not in id_cols] - # do conversion to numpy array after the above statement to avoid issue of - # having a numpy array in a conditional - feature.spans = numpy.array(feature.spans) - feature.start = feature.start or int(feature.spans.min()) - feature.stop = feature.stop or int(feature.spans.max()) - vals = [feature[col] for col in cols] + [self._get_biotype_id(feature.biotype)] + record = feature.to_record(fields=cols, array_to_blob=True) + vals = [record[col] for col in cols] + [self._get_biotype_id(feature.biotype)] cols += ["biotype_id"] placeholders = ",".join("?" 
* len(cols)) sql = f"INSERT INTO feature({','.join(cols)}) VALUES ({placeholders}) RETURNING id" @@ -383,7 +458,9 @@ def get_features_matching( table_name="gff", columns=columns, **query_args ): result = dict(zip(columns, result)) - result["spans"] = [tuple(c) for c in result["spans"]] + result["spans"] = [ + tuple(c) for c in elt_mixin.blob_to_array(result["spans"]) + ] yield result def get_feature_children( @@ -398,7 +475,9 @@ def get_feature_children( table_name="parent_to_child", columns=cols, parent_stableid=name, **kwargs ): result = dict(zip(cols, result)) - result["spans"] = [tuple(c) for c in result["spans"]] + result["spans"] = [ + tuple(c) for c in elt_mixin.blob_to_array(result["spans"]) + ] results[result["name"]] = result return list(results.values()) @@ -414,6 +493,9 @@ def get_feature_parent( table_name="child_to_parent", columns=cols, child_stableid=name ): result = dict(zip(cols, result)) + result["spans"] = [ + tuple(c) for c in elt_mixin.blob_to_array(result["spans"]) + ] results[result["name"]] = result return list(results.values()) @@ -433,11 +515,15 @@ def get_records_matching( k: v for k, v in locals().items() if k not in ("self", "allow_partial") } sql, vals = _select_records_sql("gff", kwargs, allow_partial=allow_partial) - col_names = None + cols = None for result in self._execute_sql(sql, values=vals): - if col_names is None: - col_names = result.keys() - yield {c: result[c] for c in col_names} + if cols is None: + cols = result.keys() + result = dict(zip(cols, result)) + result["spans"] = [ + tuple(c) for c in elt_mixin.blob_to_array(result["spans"]) + ] + yield result def biotype_counts(self) -> dict[str, int]: sql = "SELECT biotype, COUNT(*) as count FROM gff GROUP BY biotype" @@ -529,7 +615,7 @@ def get_stableid_prefixes(records: typing.Sequence[EnsemblGffRecord]) -> set[str """returns the prefixes of the stableids""" prefixes = set() for record in records: - record.update_from_attrs() + record.update_record() try: prefix = elt_util.get_stableid_prefix(record.stableid) except ValueError: diff --git a/src/ensembl_lite/_storage_mixin.py b/src/ensembl_lite/_storage_mixin.py index 6a968a5..ddbeb97 100644 --- a/src/ensembl_lite/_storage_mixin.py +++ b/src/ensembl_lite/_storage_mixin.py @@ -1,13 +1,45 @@ import contextlib import dataclasses +import functools +import io import os import sqlite3 +import numpy + from ensembl_lite import _util as elt_util ReturnType = tuple[str, tuple] # the sql statement and corresponding values +@functools.singledispatch +def array_to_blob(data: numpy.ndarray) -> bytes: + with io.BytesIO() as out: + numpy.save(out, data) + out.seek(0) + output = out.read() + return output + + +@array_to_blob.register +def _(data: bytes) -> bytes: + # already a blob + return data + + +@functools.singledispatch +def blob_to_array(data: bytes) -> numpy.ndarray: + with io.BytesIO(data) as out: + out.seek(0) + result = numpy.load(out) + return result + + +@blob_to_array.register +def _(data: numpy.ndarray) -> numpy.ndarray: + return data + + def _make_table_sql( table_name: str, columns: dict, diff --git a/tests/test_align.py b/tests/test_align.py index 18cedd1..5319380 100644 --- a/tests/test_align.py +++ b/tests/test_align.py @@ -274,7 +274,6 @@ def test_select_alignment_minus_strand(start_end, namer): spans=[(max(1, start or 0), min(end or 12, 12))], ) expect = aln[ft.map.start : min(ft.map.end, 12)] - # mouse sequence is on minus strand, so need to adjust # coordinates for query s2 = aln.get_seq("s2") @@ -305,7 +304,7 @@ def 
test_select_alignment_minus_strand(start_end, namer): ) ) # drop the strand info - assert len(got) == 1 + assert len(got) == 1, f"{s2_ft=}" assert got[0].to_dict() == expect.to_dict() diff --git a/tests/test_dbs.py b/tests/test_dbs.py index 6fd7cbe..20e5026 100644 --- a/tests/test_dbs.py +++ b/tests/test_dbs.py @@ -6,6 +6,7 @@ from ensembl_lite import _align as elt_align from ensembl_lite import _homology as elt_homology from ensembl_lite import _maf as elt_maf +from ensembl_lite import _storage_mixin as elt_mixin @pytest.fixture(scope="function") @@ -106,3 +107,27 @@ def test_pickling_db(db_align): pkl = pickle.dumps(db_align) # nosec B301 upkl = pickle.loads(pkl) # nosec B301 assert db_align.source == upkl.source + + +@pytest.mark.parametrize( + "data", (numpy.array([], dtype=numpy.int32), numpy.array([0, 3], dtype=numpy.uint8)) +) +def test_array_blob_roundtrip(data): + blob = elt_mixin.array_to_blob(data) + assert isinstance(blob, bytes) + inflated = elt_mixin.blob_to_array(blob) + assert numpy.array_equal(inflated, data) + assert inflated.dtype is data.dtype + + +@pytest.mark.parametrize( + "data", + ( + numpy.array([0, 3], dtype=numpy.uint8), + elt_mixin.array_to_blob(numpy.array([0, 3], dtype=numpy.uint8)), + ), +) +def test_blob_array(data): + # handles array or bytes as input + inflated = elt_mixin.blob_to_array(data) + assert numpy.array_equal(inflated, numpy.array([0, 3], dtype=numpy.uint8)) diff --git a/tests/test_genome.py b/tests/test_genome.py index f0b252d..0aa83a7 100644 --- a/tests/test_genome.py +++ b/tests/test_genome.py @@ -4,6 +4,7 @@ import pytest from cogent3 import make_unaligned_seqs from ensembl_lite import _genome as elt_genome +from ensembl_lite import _storage_mixin as elt_mixin from numpy.testing import assert_allclose @@ -379,6 +380,45 @@ def test_gff_record_hashing(): assert v == n +@pytest.mark.parametrize("exclude_null", (True, False)) +def test_gff_record_to_record(exclude_null): + data = { + "seqid": "s1", + "name": "gene-01", + "biotype": "gene", + "spans": [(1, 3), (7, 9)], + "start": 1, + "stop": 9, + "strand": "+", + } + all_fields = ( + {} if exclude_null else {s: None for s in elt_genome.EnsemblGffRecord.__slots__} + ) + all_fields.pop("_is_updated", None) + record = elt_genome.EnsemblGffRecord(**data) + got = record.to_record(exclude_null=exclude_null) + expect = all_fields | data + expect.pop("spans") + got_spans = got.pop("spans") + assert got == expect + assert numpy.array_equal(elt_mixin.blob_to_array(got_spans), data["spans"]) + + +@pytest.mark.parametrize("exclude_null", (True, False)) +def test_gff_record_to_record_selected_fields(exclude_null): + data = { + "seqid": "s1", + "name": "gene-01", + "start": None, + "stop": None, + } + fields = list(data) + record = elt_genome.EnsemblGffRecord(**data) + got = record.to_record(fields=fields, exclude_null=exclude_null) + expect = {f: data[f] for f in fields if data[f] is not None or not exclude_null} + assert got == expect + + @pytest.fixture def ensembl_gff_records(DATA_DIR): records, _ = elt_genome.custom_gff_parser(
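
A minimal usage sketch of the record workflow this patch adds (illustrative only, not part of the
diff); the module aliases follow the imports used in the tests above, and the example data mirrors
test_gff_record_to_record:

    import numpy

    from ensembl_lite import _genome as elt_genome
    from ensembl_lite import _storage_mixin as elt_mixin

    record = elt_genome.EnsemblGffRecord(
        seqid="s1", name="gene-01", biotype="gene", spans=[(1, 3), (7, 9)], strand="+"
    )

    # update_record() derives start / stop from the spans; the _is_updated flag
    # makes repeat calls no-ops
    record.update_record()
    assert (record.start, record.stop) == (1, 9)

    # spans as a 32-bit numpy array, plus a db-ready dict; array_to_blob=True
    # serialises the spans to bytes for a BLOB column (cogent3's sqlite "array"
    # column type does not need this, per the commit message)
    spans = record.spans_array()
    row = record.to_record(array_to_blob=True, exclude_null=True)
    assert isinstance(row["spans"], bytes)

    # blob_to_array() inflates the BLOB back to the original coordinates
    assert numpy.array_equal(elt_mixin.blob_to_array(row["spans"]), spans)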