Skip to content

Commit

Permalink
Release 0.1.3 (#17)
Browse files Browse the repository at this point in the history
* [pre-commit.ci] pre-commit autoupdate (#7)

* [pre-commit.ci] pre-commit autoupdate

updates:
- [github.com/astral-sh/ruff-pre-commit: v0.7.3 → v0.9.6](astral-sh/ruff-pre-commit@v0.7.3...v0.9.6)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Add mean activation (#12)

* Add mean activation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Raise error on interpretability of cont. cov. (#13)

* Raise error on interpretability of cont. cov.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Improve error communication when Merlin not installed (#14)

* Improve error communication when Merlin not installed

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add merlin for doc creation

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Bump version (#15)

* bump version 0.1.3 (#16)

* Bump version

* update CHANGELOG.md

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
moinfar and pre-commit-ci[bot] authored Feb 12, 2025
1 parent fdc2eb6 commit c77d3ee
Show file tree
Hide file tree
Showing 14 changed files with 99 additions and 48 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ repos:
hooks:
- id: prettier
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.3
rev: v0.9.6
hooks:
- id: ruff
types_or: [python, pyi, jupyter]
Expand Down
12 changes: 9 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,17 @@

-

## [0.1.3] - 2025-02-12

- Introduce mean activation to make non-negative latents possible (docs will come later)
- Better communication when Merlin is not installed
- Raise error when interpretability is called on model with continues covariates

## [0.1.2] - 2024-11-11

- No change in DRVI code
- Fix github workflow, tests, docs, and pypi publishing pipelines
- No change in DRVI code
- Fix github workflow, tests, docs, and pypi publishing pipelines

## [0.1.0] - 2024-08-21

- Moved all files from repo to scverse cookiecutter project template
- Moved all files from repo to scverse cookiecutter project template
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ Unsupervised Deep Disentangled Representation of Single-Cell Omics

Please refer to the [documentation][link-docs]. In particular, the

- [Tutorials][link-tutorials], specially
- [A demo](https://drvi.readthedocs.io/latest/notebooks/general_pipeline.html) of how to train DRVI and interpret the latent dimensions.
- [API documentation][link-api], specially
- [DRVI Model](https://drvi.readthedocs.io/latest/api/generated/drvi.model.DRVI.html)
- [DRVI utility functions (tools)](https://drvi.readthedocs.io/latest/api/tools.html)
- [DRVI plotting functions](https://drvi.readthedocs.io/latest/api/plotting.html)
- [Tutorials][link-tutorials], specially
- [A demo](https://drvi.readthedocs.io/latest/notebooks/general_pipeline.html) of how to train DRVI and interpret the latent dimensions.
- [API documentation][link-api], specially
- [DRVI Model](https://drvi.readthedocs.io/latest/api/generated/drvi.model.DRVI.html)
- [DRVI utility functions (tools)](https://drvi.readthedocs.io/latest/api/tools.html)
- [DRVI plotting functions](https://drvi.readthedocs.io/latest/api/plotting.html)

## System requirements

Expand Down
18 changes: 9 additions & 9 deletions docs/contributing.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,11 @@ Specify `vX.X.X` as a tag name and create a release. For more information, see [

Please write documentation for new or changed features and use-cases. This project uses [sphinx][] with the following features:

- the [myst][] extension allows to write documentation in markdown/Markedly Structured Text
- [Numpy-style docstrings][numpydoc] (through the [napoloen][numpydoc-napoleon] extension).
- Jupyter notebooks as tutorials through [myst-nb][] (See [Tutorials with myst-nb](#tutorials-with-myst-nb-and-jupyter-notebooks))
- [Sphinx autodoc typehints][], to automatically reference annotated input and output types
- Citations (like {cite:p}`Virshup_2023`) can be included with [sphinxcontrib-bibtex](https://sphinxcontrib-bibtex.readthedocs.io/)
- the [myst][] extension allows to write documentation in markdown/Markedly Structured Text
- [Numpy-style docstrings][numpydoc] (through the [napoloen][numpydoc-napoleon] extension).
- Jupyter notebooks as tutorials through [myst-nb][] (See [Tutorials with myst-nb](#tutorials-with-myst-nb-and-jupyter-notebooks))
- [Sphinx autodoc typehints][], to automatically reference annotated input and output types
- Citations (like {cite:p}`Virshup_2023`) can be included with [sphinxcontrib-bibtex](https://sphinxcontrib-bibtex.readthedocs.io/)

See the [scanpy developer docs](https://scanpy.readthedocs.io/en/latest/dev/documentation.html) for more information
on how to write documentation.
Expand All @@ -121,10 +121,10 @@ repository.

#### Hints

- If you refer to objects from other packages, please add an entry to `intersphinx_mapping` in `docs/conf.py`. Only
if you do so can sphinx automatically create a link to the external documentation.
- If building the documentation fails because of a missing link that is outside your control, you can add an entry to
the `nitpick_ignore` list in `docs/conf.py`
- If you refer to objects from other packages, please add an entry to `intersphinx_mapping` in `docs/conf.py`. Only
if you do so can sphinx automatically create a link to the external documentation.
- If building the documentation fails because of a missing link that is outside your control, you can add an entry to
the `nitpick_ignore` list in `docs/conf.py`

#### Building the docs locally

Expand Down
2 changes: 1 addition & 1 deletion docs/extensions/typed_returns.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
def _process_return(lines: Iterable[str]) -> Generator[str, None, None]:
for line in lines:
if m := re.fullmatch(r"(?P<param>\w+)\s+:\s+(?P<type>[\w.]+)", line):
yield f'-{m["param"]} (:class:`~{m["type"]}`)'
yield f"-{m['param']} (:class:`~{m['type']}`)"
else:
yield line

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ requires = ["hatchling"]

[project]
name = "drvi-py"
version = "0.1.2"
version = "0.1.3"
description = "Disentangled Generative Representation of Single Cell Omics"
readme = "README.md"
requires-python = ">=3.10,<3.13"
Expand Down Expand Up @@ -69,7 +69,7 @@ dev = [
]
doc = [
# Disable for now as nvidia servers return 404
# "merlin-dataloader==23.8.0",
"merlin-dataloader==23.8.0",
"docutils>=0.8,!=0.18.*,!=0.19.*",
"sphinx>=4",
"sphinx-book-theme>=1.0.0",
Expand Down
6 changes: 2 additions & 4 deletions src/drvi/nn_modules/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ def __repr__(self):
if self._freeze_hook is None:
return f"Emb({self.num_embeddings}, {self.embedding_dim})"
else:
return (
f"Emb({self.num_embeddings}, {self.embedding_dim} | " f"freeze: {self.n_freeze_x}, {self.n_freeze_y})"
)
return f"Emb({self.num_embeddings}, {self.embedding_dim} | freeze: {self.n_freeze_x}, {self.n_freeze_y})"


class MultiEmbedding(nn.Module):
Expand Down Expand Up @@ -103,7 +101,7 @@ def from_pretrained(cls, feature_embedding_instance):
def load_weights_from_trained_module(self, other, freeze_old=False):
assert len(self.emb_list) >= len(other.emb_list)
if len(self.emb_list) > len(other.emb_list):
logging.warning(f"Extending feature embedding {other} to {self} " f"with more feature categories.")
logging.warning(f"Extending feature embedding {other} to {self} with more feature categories.")
else:
logging.info(f"Extending feature embedding {other} to {self}")
for self_emb, other_emb in zip(self.emb_list, other.emb_list, strict=False):
Expand Down
4 changes: 1 addition & 3 deletions src/drvi/nn_modules/feature_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ def __init__(self, feature_info_str_list: list[str], axis="var", total_dim=None,
self.axis = axis
if any(fi.dim is None for fi in self.feature_info_list):
if total_dim is None and default_dim is None:
raise ValueError(
f"missing dim in {feature_info_str_list}\n" f"Please provide `total_dim` or `default_dim`"
)
raise ValueError(f"missing dim in {feature_info_str_list}\nPlease provide `total_dim` or `default_dim`")
if total_dim is not None:
self._fill_with_total_dim(total_dim)
if default_dim is not None:
Expand Down
31 changes: 23 additions & 8 deletions src/drvi/scvi_tools_based/merlin_data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,21 @@
import importlib
import logging

logger = logging.getLogger(__name__)

def get_placeholder(name, error_message, allow_init=False):
error_message = error_message + f" Cannot use '{name}'."

class ClassLevelGetAttrMeta(type):
def __getattr__(cls, name):
raise ImportError(error_message)

class LazyNonExistingModulePlaceholder(metaclass=ClassLevelGetAttrMeta):
def __init__(self):
if not allow_init:
raise ImportError(error_message)
super().__init__()

return LazyNonExistingModulePlaceholder


if importlib.util.find_spec("merlin"):
from . import fields
Expand All @@ -10,12 +24,13 @@
from ._data_manager import MerlinDataManager
from ._data_splitter import MerlinDataSplitter
else:
fields = None
MerlinData = None
MerlinTransformedDataLoader = None
MerlinDataManager = None
MerlinDataSplitter = None
logger.warning("Merlin is not installed. To use merline dataloader please install it.")
error_msg = "Merlin is not installed. To use merline dataloader please install it."
fields = get_placeholder("fields", error_msg)
MerlinData = get_placeholder("MerlinData", error_msg)
MerlinTransformedDataLoader = get_placeholder("MerlinTransformedDataLoader", error_msg)
MerlinDataManager = get_placeholder("MerlinDataManager", error_msg)
MerlinDataSplitter = get_placeholder("MerlinDataSplitter", error_msg)


__all__ = [
"MerlinData",
Expand Down
2 changes: 1 addition & 1 deletion src/drvi/scvi_tools_based/model/base/_archesmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def load_query_data(
raise ValueError("It appears you are loading a model from a different class.")

if _SETUP_ARGS_KEY not in registry:
raise ValueError("Saved model does not contain original setup inputs. " "Cannot load the original setup.")
raise ValueError("Saved model does not contain original setup inputs. Cannot load the original setup.")

cls.setup_anndata(
adata,
Expand Down
11 changes: 8 additions & 3 deletions src/drvi/scvi_tools_based/module/_drvi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Callable, Iterable, Sequence
from collections.abc import Iterable, Sequence
from typing import Literal

import numpy as np
Expand Down Expand Up @@ -83,7 +83,10 @@ class DRVIModule(BaseModuleClass):
prior_init_dataloader
Dataloader constructed to initialize the prior (or maintain in vamp).
var_activation
Callable used to ensure positivity of the variational distributions' variance.
The activation function to ensure positivity of the variatinal distribution. Defaults to "exp".
mean_activation
The activation function at the end of mean encoder. Defaults to "identity".
Possible values are "identity", "relu", "leaky_relu", "leaky_relu_{slope}", "elu", "elu_{min_vaule}".
encoder_layer_factory
A layer Factory instance for build encoder layers
decoder_layer_factory
Expand Down Expand Up @@ -145,7 +148,8 @@ def __init__(
] = "pnb_softmax",
prior: Literal["normal", "gmm_x", "vamp_x"] = "normal",
prior_init_dataloader: DataLoader | None = None,
var_activation: Callable | Literal["exp", "pow2"] = "exp",
var_activation: Literal["exp", "pow2"] = "exp",
mean_activation: str = "identity",
encoder_layer_factory: LayerFactory = None,
decoder_layer_factory: LayerFactory = None,
extra_encoder_kwargs: dict | None = None,
Expand Down Expand Up @@ -201,6 +205,7 @@ def __init__(
affine_batch_norm=affine_batch_norm_encoder,
use_layer_norm=use_layer_norm_encoder,
var_activation=var_activation,
mean_activation=mean_activation,
layer_factory=encoder_layer_factory,
covariate_modeling_strategy=covariate_modeling_strategy,
categorical_covariate_dims=categorical_covariate_dims if self.encode_covariates else [],
Expand Down
33 changes: 26 additions & 7 deletions src/drvi/scvi_tools_based/nn/_base_components.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import collections
import math
from collections.abc import Callable, Iterable, Sequence
from collections.abc import Iterable, Sequence
from typing import Literal

import torch
Expand Down Expand Up @@ -390,8 +390,10 @@ class Encoder(nn.Module):
Minimum value for the variance;
used for numerical stability
var_activation
Callable used to ensure positivity of the variance.
Defaults to :meth:`torch.exp`.
The activation function to ensure positivity of the variance. Defaults to "exp".
mean_activation
The activation function at the end of mean encoder. Defaults to "identity".
Possible values are "identity", "relu", "leaky_relu", "leaky_relu_{slope}", "elu", "elu_{min_vaule}".
layer_factory
A layer Factory instance for building layers
layers_location
Expand Down Expand Up @@ -419,7 +421,8 @@ def __init__(
dropout_rate: float = 0.1,
distribution: str = "normal",
var_eps: float = 1e-4,
var_activation: Callable | Literal["exp", "pow2"] = "exp",
var_activation: Literal["exp", "pow2"] = "exp",
mean_activation: str = "identity",
layer_factory: LayerFactory = None,
covariate_modeling_strategy: Literal[
"one_hot",
Expand Down Expand Up @@ -497,8 +500,24 @@ def __init__(
elif var_activation == "pow2":
self.var_activation = lambda x: torch.pow(x, 2)
else:
assert callable(var_activation)
self.var_activation = var_activation
raise NotImplementedError()

if mean_activation == "identity":
self.mean_activation = nn.Identity()
elif mean_activation == "relu":
self.mean_activation = nn.ReLU()
elif mean_activation.startswith("leaky_relu"):
if mean_activation == "leaky_relu":
mean_activation = "leaky_relu_0.01"
slope = float(mean_activation.split("leaky_relu_")[1])
self.mean_activation = nn.LeakyReLU(negative_slope=slope)
elif mean_activation.startswith("elu"):
if mean_activation == "elu":
mean_activation = "elu_1.0"
alpha = float(mean_activation.split("elu_")[1])
self.mean_activation = nn.ELU(alpha=alpha)
else:
raise NotImplementedError()

def forward(self, x: torch.Tensor, cat_full_tensor: torch.Tensor, cont_full_tensor: torch.Tensor = None):
r"""The forward computation for a single sample.
Expand All @@ -524,7 +543,7 @@ def forward(self, x: torch.Tensor, cat_full_tensor: torch.Tensor, cont_full_tens
x = torch.cat((x, cont_full_tensor), dim=-1)
# Parameters for latent distribution
q = self.encoder(self.input_dropout(x), cat_full_tensor) if self.encoder is not None else x
q_m = self.mean_encoder(q, cat_full_tensor)
q_m = self.mean_activation(self.mean_encoder(q, cat_full_tensor))
q_v = self.var_activation(self.var_encoder(q, cat_full_tensor)) + self.var_eps
dist = Normal(q_m, q_v.sqrt())
latent = self.z_transformation(dist.rsample())
Expand Down
3 changes: 3 additions & 0 deletions src/drvi/utils/tools/interpretability/_latent_traverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ def make_traverse_adata(
else:
cat_vector = None

if model.adata_manager.get_state_registry(scvi.REGISTRY_KEYS.CONT_COVS_KEY):
raise NotImplementedError("Interpretability of models with continuous covariates are not implemented yet.")

# lib size
lib_vector = np.ones(n_samples) * 1e4
lib_vector = lib_vector[span_adata.obs["sample_id"]]
Expand Down
7 changes: 7 additions & 0 deletions tests/drvi_model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,13 @@ def test_simple_integration_latent_splitting(self):
adata, n_latent=32, n_split_latent=8, split_method="split", split_aggregation="max"
)

def test_simple_integration_mean_activation(self):
adata = self.make_test_adata()
self._general_integration_test(adata, n_latent=32, n_split_latent=32, mean_activation="identity")
self._general_integration_test(adata, n_latent=32, n_split_latent=32, mean_activation="relu")
self._general_integration_test(adata, n_latent=32, n_split_latent=32, mean_activation="leaky_relu_0.4")
self._general_integration_test(adata, n_latent=32, n_split_latent=32, mean_activation="elu_0.4")

def test_decoder_reusing(self):
adata = self.make_test_adata()
for reuse_strategy in ["nowhere"]:
Expand Down

0 comments on commit c77d3ee

Please sign in to comment.