diff --git a/docs/developer.rst b/docs/developer.rst index 26466b03a..68e9b110d 100644 --- a/docs/developer.rst +++ b/docs/developer.rst @@ -12,8 +12,6 @@ Backends backends.ott.GWSolver backends.ott.OTTOutput backends.ott.GraphOTTOutput - backends.ott.GENOTLinSolver - backends.ott.output.OTTNeuralOutput backends.utils.get_solver backends.utils.get_available_backends @@ -46,7 +44,6 @@ Problems problems.BaseCompoundProblem problems.CompoundProblem cost.BaseCost - problems.CondOTProblem Mixins ^^^^^^ diff --git a/docs/user.rst b/docs/user.rst index ccc769697..2c8d19448 100644 --- a/docs/user.rst +++ b/docs/user.rst @@ -27,7 +27,6 @@ Generic Problems generic.SinkhornProblem generic.GWProblem generic.FGWProblem - generic.GENOTLinProblem Plotting ~~~~~~~~ diff --git a/src/moscot/_types.py b/src/moscot/_types.py index 6315bdb1a..c68610dab 100644 --- a/src/moscot/_types.py +++ b/src/moscot/_types.py @@ -2,19 +2,18 @@ from typing import Any, Literal, Mapping, Optional, Sequence, Union import numpy as np +from jax import Array as JaxArray +from numpy.typing import DTypeLike as DTypeLikeNumpy +from numpy.typing import NDArray from ott.initializers.linear.initializers import SinkhornInitializer from ott.initializers.linear.initializers_lr import LRInitializer from ott.initializers.quadratic.initializers import BaseQuadraticInitializer # TODO(michalk8): polish -try: - from numpy.typing import DTypeLike, NDArray - ArrayLike = NDArray[np.floating] -except (ImportError, TypeError): - ArrayLike = np.ndarray # type: ignore[misc] - DTypeLike = np.dtype # type: ignore[misc] +ArrayLike = Union[NDArray[np.floating], JaxArray] +DTypeLike = DTypeLikeNumpy ProblemKind_t = Literal["linear", "quadratic", "unknown"] Numeric_t = Union[int, float] # type of `time_key` arguments diff --git a/src/moscot/backends/ott/__init__.py b/src/moscot/backends/ott/__init__.py index 48ffdec64..7fdae526c 100644 --- a/src/moscot/backends/ott/__init__.py +++ b/src/moscot/backends/ott/__init__.py @@ -1,11 +1,19 @@ from ott.geometry import costs from moscot.backends.ott._utils import sinkhorn_divergence -from moscot.backends.ott.output import GraphOTTOutput, OTTNeuralOutput, OTTOutput +from moscot.backends.ott.output import GraphOTTOutput, NeuralOutput, OTTOutput from moscot.backends.ott.solver import GENOTLinSolver, GWSolver, SinkhornSolver from moscot.costs import register_cost -__all__ = ["OTTOutput", "GWSolver", "SinkhornSolver", "OTTNeuralOutput", "sinkhorn_divergence", "GENOTLinSolver"] +__all__ = [ + "OTTOutput", + "GWSolver", + "SinkhornSolver", + "NeuralOutput", + "sinkhorn_divergence", + "GENOTLinSolver", + "GraphOTTOutput", +] register_cost("euclidean", backend="ott")(costs.Euclidean) diff --git a/src/moscot/backends/ott/_utils.py b/src/moscot/backends/ott/_utils.py index 9f71e2d5a..58bc5b82a 100644 --- a/src/moscot/backends/ott/_utils.py +++ b/src/moscot/backends/ott/_utils.py @@ -184,7 +184,7 @@ def alpha_to_fused_penalty(alpha: float) -> float: return (1 - alpha) / alpha -def densify(arr: ArrayLike) -> jax.Array: +def densify(arr: Union[ArrayLike, sp.sparray, sp.spmatrix]) -> jax.Array: """If the input is sparse, convert it to dense. Parameters @@ -197,7 +197,8 @@ def densify(arr: ArrayLike) -> jax.Array: dense :mod:`jax` array. 
""" if sp.issparse(arr): - arr = arr.toarray() # type: ignore[attr-defined] + arr_sp: Union[sp.sparray, sp.spmatrix] = arr + arr = arr_sp.toarray() elif isinstance(arr, jesp.BCOO): arr = arr.todense() return jnp.asarray(arr) diff --git a/src/moscot/backends/ott/output.py b/src/moscot/backends/ott/output.py index 60d727faf..988004caa 100644 --- a/src/moscot/backends/ott/output.py +++ b/src/moscot/backends/ott/output.py @@ -17,7 +17,7 @@ from moscot.backends.ott._utils import get_nearest_neighbors from moscot.base.output import BaseDiscreteSolverOutput, BaseNeuralOutput -__all__ = ["OTTOutput", "GraphOTTOutput", "OTTNeuralOutput"] +__all__ = ["OTTOutput", "GraphOTTOutput", "NeuralOutput"] class OTTOutput(BaseDiscreteSolverOutput): @@ -182,6 +182,9 @@ def _apply(self, x: ArrayLike, *, forward: bool) -> ArrayLike: axis=1 - forward, ).T # convert to batch first + def _apply_forward(self, x: ArrayLike) -> ArrayLike: + return self._apply(x, forward=True) + @property def shape(self) -> Tuple[int, int]: # noqa: D102 if isinstance(self._output, sinkhorn.SinkhornOutput): @@ -241,11 +244,11 @@ def _ones(self, n: int) -> ArrayLike: # noqa: D102 return jnp.ones((n,)) -class OTTNeuralOutput(BaseNeuralOutput): +class NeuralOutput(BaseNeuralOutput): """Output wrapper for GENOT.""" def __init__(self, model: GENOT, logs: dict[str, list[float]]): - """Initialize `OTTNeuralOutput`. + """Initialize `NeuralOutput`. Parameters ---------- @@ -269,8 +272,7 @@ def _project_transport_matrix( self, src_dist: ArrayLike, tgt_dist: ArrayLike, - forward: bool, - func: Callable[[jnp.ndarray], jnp.ndarray], + func: Callable[[ArrayLike], ArrayLike], save_transport_matrix: bool = False, # TODO(@MUCDK) adapt order of arguments batch_size: int = 1024, k: int = 30, @@ -279,9 +281,9 @@ def _project_transport_matrix( recall_target: float = 0.95, aggregate_to_topk: bool = True, ) -> sp.csr_matrix: - row_indices: Union[jnp.ndarray, List[jnp.ndarray]] = [] - column_indices: Union[jnp.ndarray, List[jnp.ndarray]] = [] - distances_list: Union[jnp.ndarray, List[jnp.ndarray]] = [] + row_indices: List[ArrayLike] = [] + column_indices: List[ArrayLike] = [] + distances_list: List[ArrayLike] = [] if length_scale is None: key = jax.random.PRNGKey(seed) src_batch = src_dist[jax.random.choice(key, src_dist.shape[0], shape=((batch_size,)))] @@ -306,20 +308,14 @@ def _project_transport_matrix( row_indices = jnp.concatenate(row_indices) column_indices = jnp.concatenate(column_indices) tm = sp.csr_matrix((distances, (row_indices, column_indices)), shape=[len(src_dist), len(tgt_dist)]) - if forward: - if save_transport_matrix: - self._transport_matrix = tm - else: - tm = tm.T - if save_transport_matrix: - self._inverse_transport_matrix = tm + if save_transport_matrix: + self._transport_matrix = tm return tm def project_to_transport_matrix( # type:ignore[override] self, src_cells: ArrayLike, tgt_cells: ArrayLike, - forward: bool = True, condition: ArrayLike = None, save_transport_matrix: bool = False, # TODO(@MUCDK) adapt order of arguments batch_size: int = 1024, @@ -351,7 +347,7 @@ def project_to_transport_matrix( # type:ignore[override] save_transport_matrix Whether to save the transport matrix. batch_size - Number of data points in the source distribution the neighborhoodgraph is computed + Number of data points in the source distribution the neighborhood graph is computed for in parallel. k Number of neighbors to construct the k-nearest neighbor graph of a mapped cell. 
@@ -375,13 +371,12 @@ def project_to_transport_matrix( # type:ignore[override] The projected transport matrix. """ src_cells, tgt_cells = jnp.asarray(src_cells), jnp.asarray(tgt_cells) - push = self.push if condition is None else lambda x: self.push(x, condition) - pull = self.pull if condition is None else lambda x: self.pull(x, condition) - func, src_dist, tgt_dist = (push, src_cells, tgt_cells) if forward else (pull, tgt_cells, src_cells) + conditioned_fn: Callable[[ArrayLike], ArrayLike] = lambda x: self.push(x, condition) + push = self.push if condition is None else conditioned_fn + func, src_dist, tgt_dist = (push, src_cells, tgt_cells) return self._project_transport_matrix( src_dist=src_dist, tgt_dist=tgt_dist, - forward=forward, func=func, save_transport_matrix=save_transport_matrix, # TODO(@MUCDK) adapt order of arguments batch_size=batch_size, @@ -406,31 +401,13 @@ def push(self, x: ArrayLike, cond: Optional[ArrayLike] = None) -> ArrayLike: ------- Pushed distribution. """ + if isinstance(x, (bool, int, float, complex)): + raise ValueError("Expected array, found scalar value.") if x.ndim not in (1, 2): raise ValueError(f"Expected 1D or 2D array, found `{x.ndim}`.") - return self._apply(x, cond=cond, forward=True) - - def pull(self, x: ArrayLike, cond: Optional[ArrayLike] = None) -> ArrayLike: - """Pull distribution `x` conditioned on condition `cond`. - - This does not make sense for some neural models and is therefore left unimplemented. - - Parameters - ---------- - x - Distribution to push. - cond - Condition of conditional neural OT. - - Raises - ------ - NotImplementedError - """ - raise NotImplementedError("`pull` does not make sense for neural OT.") + return self._apply_forward(x, cond=cond) - def _apply(self, x: ArrayLike, forward: bool, cond: Optional[ArrayLike] = None) -> ArrayLike: - if not forward: - raise NotImplementedError("Backward i.e., pull on neural OT is not supported.") + def _apply_forward(self, x: ArrayLike, cond: Optional[ArrayLike] = None) -> ArrayLike: return self._model.transport(x, condition=cond) @property @@ -445,7 +422,7 @@ def shape(self) -> Tuple[int, int]: def to( self, device: Optional[Device_t] = None, - ) -> "OTTNeuralOutput": + ) -> "NeuralOutput": """Transfer the output to another device or change its data type. 
Parameters @@ -471,7 +448,7 @@ def to( # raise IndexError(f"Unable to fetch the device with `id={idx}`.") from err # out = jax.device_put(self._model, device) - # return OTTNeuralOutput(out) + # return NeuralOutput(out) return self # TODO(ilan-gold) move model to device @property diff --git a/src/moscot/backends/ott/solver.py b/src/moscot/backends/ott/solver.py index 361fe5d5a..50ca82fa4 100644 --- a/src/moscot/backends/ott/solver.py +++ b/src/moscot/backends/ott/solver.py @@ -53,7 +53,7 @@ densify, ensure_2d, ) -from moscot.backends.ott.output import GraphOTTOutput, OTTNeuralOutput, OTTOutput +from moscot.backends.ott.output import GraphOTTOutput, NeuralOutput, OTTOutput from moscot.base.problems._utils import TimeScalesHeatKernel from moscot.base.solver import OTSolver from moscot.costs import get_cost @@ -699,10 +699,10 @@ def solver(self) -> genot.GENOT: def _call_kwargs(cls) -> Tuple[Set[str], Set[str]]: return {"batch_size", "train_size", "trainloader", "validloader", "seed"}, {} # type: ignore[return-value] - def _solve(self, data_samplers: Tuple[MultiLoader, MultiLoader]) -> OTTNeuralOutput: # type: ignore[override] + def _solve(self, data_samplers: Tuple[MultiLoader, MultiLoader]) -> NeuralOutput: # type: ignore[override] seed = self._neural_kwargs.get("seed", 0) # TODO(ilan-gold): unify rng hadnling like OTT tests rng = jax.random.PRNGKey(seed) logs = self.solver( data_samplers[0], n_iters=self._neural_kwargs.get("n_iters", 100), rng=rng ) # TODO(ilan-gold): validation and figure out defualts - return OTTNeuralOutput(self.solver, logs) + return NeuralOutput(self.solver, logs) diff --git a/src/moscot/backends/utils.py b/src/moscot/backends/utils.py index fde874c0f..988e05413 100644 --- a/src/moscot/backends/utils.py +++ b/src/moscot/backends/utils.py @@ -42,8 +42,7 @@ def register_solver( return _REGISTRY.register(backend) # type: ignore[return-value] -# TODO(@MUCDK) fix mypy error -@register_solver("ott") # type: ignore[arg-type] +@register_solver("ott") def _( problem_kind: Literal["linear", "quadratic"], solver_name: Optional[Literal["GENOTLinSolver"]] = None, diff --git a/src/moscot/base/cost.py b/src/moscot/base/cost.py index 8bea310be..a8e74b3e1 100644 --- a/src/moscot/base/cost.py +++ b/src/moscot/base/cost.py @@ -58,7 +58,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ArrayLike: f"Cost matrix contains `{np.sum(np.isnan(cost))}` NaN values, " f"setting them to the maximum value `{maxx}`." 
) - cost = np.nan_to_num(cost, nan=maxx) # type: ignore[call-overload] + cost = np.nan_to_num(cost, nan=maxx) # type: ignore[arg-type, type-var] if np.any(cost < 0): raise ValueError(f"Cost matrix contains `{np.sum(cost < 0)}` negative values.") return cost diff --git a/src/moscot/base/output.py b/src/moscot/base/output.py index 2570e2cb1..4e10203f3 100644 --- a/src/moscot/base/output.py +++ b/src/moscot/base/output.py @@ -11,7 +11,7 @@ from scipy.sparse.linalg import LinearOperator from moscot._logging import logger -from moscot._types import ArrayLike, Device_t, DTypeLike # type: ignore[attr-defined] +from moscot._types import ArrayLike, Device_t, DTypeLike __all__ = ["BaseDiscreteSolverOutput", "MatrixSolverOutput", "BaseNeuralOutput"] @@ -19,19 +19,24 @@ class BaseSolverOutput(abc.ABC): """Base class for all solver outputs.""" - @abc.abstractmethod - def pull(self, x: ArrayLike, **kwargs) -> ArrayLike: - """Pull the solution based on a condition.""" - @abc.abstractmethod def push(self, x: ArrayLike, **kwargs) -> ArrayLike: """Push the solution based on a condition.""" + @abc.abstractmethod + def _apply_forward(self, x: ArrayLike) -> ArrayLike: + """Apply the transport matrix in the forward direction.""" + @property @abc.abstractmethod def shape(self) -> tuple[int, int]: """Shape of the problem.""" + @property + @abc.abstractmethod + def converged(self) -> bool: + """Whether the solver converged.""" + @abc.abstractmethod def to(self: BaseSolverOutput, device: Optional[Device_t] = None) -> BaseSolverOutput: """Transfer self to another compute device. @@ -74,11 +79,6 @@ def transport_matrix(self) -> ArrayLike: def cost(self) -> float: """Regularized :term:`OT` cost.""" - @property - @abc.abstractmethod - def converged(self) -> bool: - """Whether the solver converged.""" - @property @abc.abstractmethod def potentials(self) -> Optional[tuple[ArrayLike, ArrayLike]]: @@ -348,6 +348,9 @@ def _apply(self, x: ArrayLike, *, forward: bool) -> ArrayLike: return self.transport_matrix.T @ x return self.transport_matrix @ x + def _apply_forward(self, x: ArrayLike) -> ArrayLike: + return self._apply(x, forward=True) + @property def transport_matrix(self) -> ArrayLike: # noqa: D102 return self._transport_matrix @@ -393,7 +396,7 @@ def _ones(self, n: int) -> ArrayLike: return jnp.ones((n,), dtype=self.transport_matrix.dtype) -class BaseNeuralOutput(BaseDiscreteSolverOutput, abc.ABC): +class BaseNeuralOutput(BaseSolverOutput, abc.ABC): """Base class for output of.""" @abstractmethod @@ -402,7 +405,6 @@ def project_to_transport_matrix( source: Optional[ArrayLike] = None, target: Optional[ArrayLike] = None, condition: Optional[ArrayLike] = None, - forward: bool = True, save_transport_matrix: bool = False, batch_size: int = 1024, k: int = 30, @@ -410,19 +412,3 @@ def project_to_transport_matrix( seed: int = 42, ) -> sp.csr_matrix: """Project transport matrix.""" - pass - - @property - def transport_matrix(self): # noqa: D102 - raise NotImplementedError("Neural output does not require a transport matrix.") - - @property - def cost(self): # noqa: D102 - raise NotImplementedError("Neural output does not implement a cost property.") - - @property - def potentials(self): # noqa: D102 - raise NotImplementedError("Neural output does not need to implement a potentials property.") - - def _ones(self, n: int): # noqa: D102 - raise NotImplementedError("Neural output does not need to implement a `_ones` property.") diff --git a/src/moscot/base/problems/__init__.py b/src/moscot/base/problems/__init__.py index 
544631d48..b554b3ce5 100644 --- a/src/moscot/base/problems/__init__.py +++ b/src/moscot/base/problems/__init__.py @@ -2,7 +2,7 @@ from moscot.base.problems.birth_death import BirthDeathMixin, BirthDeathProblem from moscot.base.problems.compound_problem import BaseCompoundProblem, CompoundProblem from moscot.base.problems.manager import ProblemManager -from moscot.base.problems.problem import BaseProblem, CondOTProblem, OTProblem +from moscot.base.problems.problem import BaseProblem, OTProblem __all__ = [ "AnalysisMixin", @@ -13,5 +13,4 @@ "ProblemManager", "BaseProblem", "OTProblem", - "CondOTProblem", ] diff --git a/src/moscot/base/problems/_utils.py b/src/moscot/base/problems/_utils.py index 40bf0a99a..482f79b75 100644 --- a/src/moscot/base/problems/_utils.py +++ b/src/moscot/base/problems/_utils.py @@ -386,7 +386,7 @@ def perm_test_extractor(res: Sequence[Tuple[ArrayLike, ArrayLike]]) -> Tuple[Arr corr_bs = np.concatenate(corr_bs, axis=0) corr_ci_low, corr_ci_high = np.quantile(corr_bs, q=ql, axis=0), np.quantile(corr_bs, q=qh, axis=0) - return pvals, corr_ci_low, corr_ci_high # type:ignore[return-value] + return pvals, corr_ci_low, corr_ci_high if not (0 <= confidence_level <= 1): raise ValueError(f"Expected `confidence_level` to be in interval `[0, 1]`, found `{confidence_level}`.") diff --git a/src/moscot/base/problems/problem.py b/src/moscot/base/problems/problem.py index 8cebbb639..811ac158c 100644 --- a/src/moscot/base/problems/problem.py +++ b/src/moscot/base/problems/problem.py @@ -7,12 +7,10 @@ Any, Dict, Hashable, - Iterable, List, Literal, Mapping, Optional, - Sequence, Tuple, TypeVar, Union, @@ -44,23 +42,11 @@ wrap_solve, ) from moscot.base.solver import OTSolver -from moscot.utils.subset_policy import ( # type:ignore[attr-defined] - ExplicitPolicy, - Policy_t, - StarPolicy, - SubsetPolicy, - create_policy, -) -from moscot.utils.tagged_array import ( - DistributionCollection, - DistributionContainer, - Tag, - TaggedArray, -) +from moscot.utils.tagged_array import Tag, TaggedArray K = TypeVar("K", bound=Hashable) -__all__ = ["BaseProblem", "OTProblem", "CondOTProblem"] +__all__ = ["BaseProblem", "OTProblem"] class CombinedMeta(abc.ABCMeta, NumpyDocstringInheritanceMeta): @@ -1119,209 +1105,3 @@ def __repr__(self) -> str: def __str__(self) -> str: return repr(self) - - -class CondOTProblem(BaseProblem): # TODO(@MUCDK) check generic types, save and load - """ - Base class for all conditional (nerual) optimal transport problems. - - Parameters - ---------- - adata - Source annotated data object. - kwargs - Keyword arguments for :class:`moscot.base.problems.problem.BaseProblem` - """ - - def __init__( - self, - adata: AnnData, - **kwargs: Any, - ): - super().__init__(**kwargs) - self._adata = adata - - self._distributions: Optional[DistributionCollection[K]] = None # type: ignore[valid-type] - self._policy: Optional[SubsetPolicy[Any]] = None - self._sample_pairs: Optional[List[Tuple[Any, Any]]] = None - - self._solver: Optional[OTSolver[BaseDiscreteSolverOutput]] = None - self._solution: Optional[BaseDiscreteSolverOutput] = None - - self._a: Optional[str] = None - self._b: Optional[str] = None - - @wrap_prepare - def prepare( - self, - policy_key: str, - policy: Policy_t, - xy: Mapping[str, Any], - xx: Mapping[str, Any], - conditions: Mapping[str, Any], - a: Optional[str] = None, - b: Optional[str] = None, - subset: Optional[Sequence[Tuple[K, K]]] = None, - reference: K = None, - **kwargs: Any, - ) -> "CondOTProblem": - """Prepare conditional optimal transport problem. 
- - Parameters - ---------- - xy - Geometry defining the linear term. If passed as a :class:`dict`, - :meth:`~moscot.utils.tagged_array.TaggedArray.from_adata` will be called. - policy - Policy defining which pairs of distributions to sample from during training. - policy_key - %(key)s - a - Source marginals. - b - Target marginals. - kwargs - Keyword arguments when creating the source/target marginals. - - - Returns - ------- - Self and modifies the following attributes: - TODO. - """ - self._problem_kind = "linear" - self._distributions = DistributionCollection() - self._solution = None - self._policy_key = policy_key - try: - self._distribution_id = pd.Series(self.adata.obs[policy_key]) - except KeyError: - raise KeyError(f"Unable to find data in `adata.obs[{policy_key!r}]`.") from None - - self._policy = create_policy(policy, adata=self.adata, key=policy_key) - if isinstance(self._policy, ExplicitPolicy): - self._policy = self._policy.create_graph(subset=subset) - elif isinstance(self._policy, StarPolicy): - self._policy = self._policy.create_graph(reference=reference) - else: - _ = self.policy.create_graph() # type: ignore[union-attr] - self._sample_pairs = list(self.policy._graph) # type: ignore[union-attr] - - for el in self.policy.categories: # type: ignore[union-attr] - adata_masked = self.adata[self._create_mask(el)] - a_created = self._create_marginals(adata_masked, data=a, source=True, **kwargs) - b_created = self._create_marginals(adata_masked, data=b, source=False, **kwargs) - self.distributions[el] = DistributionContainer.from_adata( # type: ignore[index] - adata_masked, a=a_created, b=b_created, **xy, **xx, **conditions - ) - return self - - @wrap_solve - def solve( - self, - backend: Literal["ott"] = "ott", - solver_name: Literal["GENOTLinSolver"] = "GENOTLinSolver", - device: Optional[Device_t] = None, - **kwargs: Any, - ) -> "CondOTProblem": - """Solve optimal transport problem. - - Parameters - ---------- - backend - Which backend to use, see :func:`moscot.backends.utils.get_available_backends`. - device - Device where to transfer the solution, see :meth:`moscot.base.output.BaseDiscreteSolverOutput.to`. - kwargs - Keyword arguments for :meth:`moscot.base.solver.BaseSolver.__call__`. - - - Returns - ------- - Self and modifies the following attributes: - - :attr:`solver`: optimal transport solver. - - :attr:`solution`: optimal transport solution. 
- """ - tmp = next(iter(self.distributions)) # type: ignore[arg-type] - input_dim = self.distributions[tmp].xy.shape[1] # type: ignore[union-attr, index] - cond_dim = self.distributions[tmp].conditions.shape[1] # type: ignore[union-attr, index] - - solver_class = backends.get_solver( - self.problem_kind, solver_name=solver_name, backend=backend, return_class=True - ) - init_kwargs, call_kwargs = solver_class._partition_kwargs(**kwargs) - self._solver = solver_class(input_dim=input_dim, cond_dim=cond_dim, **init_kwargs) - # note that the solver call consists of solver._prepare and solver._solve - sample_pairs = self._sample_pairs if self._sample_pairs is not None else [] - self._solution = self._solver( # type: ignore[misc] - device=device, - distributions=self.distributions, - sample_pairs=self._sample_pairs, - is_conditional=len(sample_pairs) > 1, - **call_kwargs, - ) - - return self - - def _create_marginals( - self, adata: AnnData, *, source: bool, data: Optional[str] = None, **kwargs: Any - ) -> ArrayLike: - if data is True: - marginals = self.estimate_marginals(adata, source=source, **kwargs) - elif data in (False, None): - marginals = np.ones((adata.n_obs,), dtype=float) / adata.n_obs - elif isinstance(data, str): - try: - marginals = np.asarray(adata.obs[data], dtype=float) - except KeyError: - raise KeyError(f"Unable to find data in `adata.obs[{data!r}]`.") from None - return marginals - - def _create_mask(self, value: Union[K, Sequence[K]], *, allow_empty: bool = False) -> ArrayLike: - """Create a mask used to subset the data. - - TODO(@MUCDK): this is copied from SubsetPolicy, consider making this a function. - - Parameters - ---------- - value - Values in the data which determine the mask. - allow_empty - Whether to allow empty mask. - - Returns - ------- - Boolean mask of the same shape as the data. 
- """ - if isinstance(value, str) or not isinstance(value, Iterable): - mask = self._distribution_id == value - else: - mask = self._distribution_id.isin(value) - if not allow_empty and not np.sum(mask): - raise ValueError("Unable to construct an empty mask, use `allow_empty=True` to override.") - return np.asarray(mask) - - @property - def distributions(self) -> Optional[DistributionCollection[K]]: - """Collection of distributions.""" - return self._distributions - - @property - def adata(self) -> AnnData: - """Source annotated data object.""" - return self._adata - - @property - def solution(self) -> Optional[BaseDiscreteSolverOutput]: - """Solution of the optimal transport problem.""" - return self._solution - - @property - def solver(self) -> Optional[OTSolver[BaseDiscreteSolverOutput]]: - """Solver of the optimal transport problem.""" - return self._solver - - @property - def policy(self) -> Optional[SubsetPolicy[Any]]: - """Policy used to subset the data.""" - return self._policy diff --git a/src/moscot/base/solver.py b/src/moscot/base/solver.py index d4ec22360..b2068007f 100644 --- a/src/moscot/base/solver.py +++ b/src/moscot/base/solver.py @@ -16,13 +16,13 @@ from moscot._logging import logger from moscot._types import ArrayLike, Device_t, ProblemKind_t -from moscot.base.output import BaseDiscreteSolverOutput +from moscot.base.output import BaseSolverOutput from moscot.utils.tagged_array import Tag, TaggedArray __all__ = ["BaseSolver", "OTSolver"] -O = TypeVar("O", bound=BaseDiscreteSolverOutput) +O = TypeVar("O", bound=BaseSolverOutput) class TaggedArrayData(NamedTuple): # noqa: D101 diff --git a/src/moscot/neural/base/__init__.py b/src/moscot/neural/base/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/moscot/neural/base/problems/__init__.py b/src/moscot/neural/base/problems/__init__.py new file mode 100644 index 000000000..ec15beb21 --- /dev/null +++ b/src/moscot/neural/base/problems/__init__.py @@ -0,0 +1,3 @@ +from moscot.neural.base.problems.problem import NeuralOTProblem + +__all__ = ["NeuralOTProblem"] diff --git a/src/moscot/neural/base/problems/problem.py b/src/moscot/neural/base/problems/problem.py new file mode 100644 index 000000000..cc142f989 --- /dev/null +++ b/src/moscot/neural/base/problems/problem.py @@ -0,0 +1,243 @@ +from typing import ( + Any, + Hashable, + Iterable, + List, + Literal, + Mapping, + Optional, + Sequence, + Tuple, + TypeVar, + Union, +) + +import numpy as np +import pandas as pd + +from anndata import AnnData + +from moscot import backends +from moscot._types import ArrayLike, Device_t +from moscot.base.output import BaseNeuralOutput +from moscot.base.problems._utils import wrap_prepare, wrap_solve +from moscot.base.problems.problem import BaseProblem +from moscot.base.solver import OTSolver +from moscot.utils.subset_policy import ( # type:ignore[attr-defined] + ExplicitPolicy, + Policy_t, + StarPolicy, + SubsetPolicy, + create_policy, +) +from moscot.utils.tagged_array import DistributionCollection, DistributionContainer + +K = TypeVar("K", bound=Hashable) + +__all__ = ["NeuralOTProblem"] + + +class NeuralOTProblem(BaseProblem): # TODO(@MUCDK) check generic types, save and load + """ + Base class for all conditional (nerual) optimal transport problems. + + Parameters + ---------- + adata + Source annotated data object. 
+ kwargs + Keyword arguments for :class:`moscot.base.problems.problem.BaseProblem` + """ + + def __init__( + self, + adata: AnnData, + **kwargs: Any, + ): + super().__init__(**kwargs) + self._adata = adata + + self._distributions: Optional[DistributionCollection[K]] = None # type: ignore[valid-type] + self._policy: Optional[SubsetPolicy[Any]] = None + self._sample_pairs: Optional[List[Tuple[Any, Any]]] = None + + self._solver: Optional[OTSolver[BaseNeuralOutput]] = None + self._solution: Optional[BaseNeuralOutput] = None + + self._a: Optional[str] = None + self._b: Optional[str] = None + + @wrap_prepare + def prepare( + self, + policy_key: str, + policy: Policy_t, + xy: Mapping[str, Any], + xx: Mapping[str, Any], + conditions: Mapping[str, Any], + a: Optional[str] = None, + b: Optional[str] = None, + subset: Optional[Sequence[Tuple[K, K]]] = None, + reference: K = None, + **kwargs: Any, + ) -> "NeuralOTProblem": + """Prepare conditional optimal transport problem. + + Parameters + ---------- + xy + Geometry defining the linear term. If passed as a :class:`dict`, + :meth:`~moscot.utils.tagged_array.TaggedArray.from_adata` will be called. + policy + Policy defining which pairs of distributions to sample from during training. + policy_key + %(key)s + a + Source marginals. + b + Target marginals. + kwargs + Keyword arguments when creating the source/target marginals. + + + Returns + ------- + Self and modifies the following attributes: + TODO. + """ + self._problem_kind = "linear" + self._distributions = DistributionCollection() + self._solution = None + self._policy_key = policy_key + try: + self._distribution_id = pd.Series(self.adata.obs[policy_key]) + except KeyError: + raise KeyError(f"Unable to find data in `adata.obs[{policy_key!r}]`.") from None + + self._policy = create_policy(policy, adata=self.adata, key=policy_key) + if isinstance(self._policy, ExplicitPolicy): + self._policy = self._policy.create_graph(subset=subset) + elif isinstance(self._policy, StarPolicy): + self._policy = self._policy.create_graph(reference=reference) + else: + _ = self.policy.create_graph() # type: ignore[union-attr] + self._sample_pairs = list(self.policy._graph) # type: ignore[union-attr] + + for el in self.policy.categories: # type: ignore[union-attr] + adata_masked = self.adata[self._create_mask(el)] + a_created = self._create_marginals(adata_masked, data=a, source=True, **kwargs) + b_created = self._create_marginals(adata_masked, data=b, source=False, **kwargs) + self.distributions[el] = DistributionContainer.from_adata( # type: ignore[index] + adata_masked, a=a_created, b=b_created, **xy, **xx, **conditions + ) + return self + + @wrap_solve + def solve( + self, + backend: Literal["ott"] = "ott", + solver_name: Literal["GENOTLinSolver"] = "GENOTLinSolver", + device: Optional[Device_t] = None, + **kwargs: Any, + ) -> "NeuralOTProblem": + """Solve optimal transport problem. + + Parameters + ---------- + backend + Which backend to use, see :func:`moscot.backends.utils.get_available_backends`. + device + Device where to transfer the solution, see :meth:`moscot.base.output.BaseNeuralOutput.to`. + kwargs + Keyword arguments for :meth:`moscot.base.solver.BaseSolver.__call__`. + + + Returns + ------- + Self and modifies the following attributes: + - :attr:`solver`: optimal transport solver. + - :attr:`solution`: optimal transport solution. 
+ """ + tmp = next(iter(self.distributions)) # type: ignore[arg-type] + input_dim = self.distributions[tmp].xy.shape[1] # type: ignore[union-attr, index] + cond_dim = self.distributions[tmp].conditions.shape[1] # type: ignore[union-attr, index] + + solver_class = backends.get_solver( + self.problem_kind, solver_name=solver_name, backend=backend, return_class=True + ) + init_kwargs, call_kwargs = solver_class._partition_kwargs(**kwargs) + self._solver = solver_class(input_dim=input_dim, cond_dim=cond_dim, **init_kwargs) + # note that the solver call consists of solver._prepare and solver._solve + sample_pairs = self._sample_pairs if self._sample_pairs is not None else [] + self._solution = self._solver( # type: ignore[misc] + device=device, + distributions=self.distributions, + sample_pairs=self._sample_pairs, + is_conditional=len(sample_pairs) > 1, + **call_kwargs, + ) + + return self + + def _create_marginals( + self, adata: AnnData, *, source: bool, data: Optional[str] = None, **kwargs: Any + ) -> ArrayLike: + if data is True: + marginals = self.estimate_marginals(adata, source=source, **kwargs) + elif data in (False, None): + marginals = np.ones((adata.n_obs,), dtype=float) / adata.n_obs + elif isinstance(data, str): + try: + marginals = np.asarray(adata.obs[data], dtype=float) + except KeyError: + raise KeyError(f"Unable to find data in `adata.obs[{data!r}]`.") from None + return marginals + + def _create_mask(self, value: Union[K, Sequence[K]], *, allow_empty: bool = False) -> ArrayLike: + """Create a mask used to subset the data. + + TODO(@MUCDK): this is copied from SubsetPolicy, consider making this a function. + + Parameters + ---------- + value + Values in the data which determine the mask. + allow_empty + Whether to allow empty mask. + + Returns + ------- + Boolean mask of the same shape as the data. 
+ """ + if isinstance(value, str) or not isinstance(value, Iterable): + mask = self._distribution_id == value + else: + mask = self._distribution_id.isin(value) + if not allow_empty and not np.sum(mask): + raise ValueError("Unable to construct an empty mask, use `allow_empty=True` to override.") + return np.asarray(mask) + + @property + def distributions(self) -> Optional[DistributionCollection[K]]: + """Collection of distributions.""" + return self._distributions + + @property + def adata(self) -> AnnData: + """Source annotated data object.""" + return self._adata + + @property + def solution(self) -> Optional[BaseNeuralOutput]: + """Solution of the optimal transport problem.""" + return self._solution + + @property + def solver(self) -> Optional[OTSolver[BaseNeuralOutput]]: + """Solver of the optimal transport problem.""" + return self._solver + + @property + def policy(self) -> Optional[SubsetPolicy[Any]]: + """Policy used to subset the data.""" + return self._policy diff --git a/src/moscot/neural/problems/__init__.py b/src/moscot/neural/problems/__init__.py new file mode 100644 index 000000000..cd884d366 --- /dev/null +++ b/src/moscot/neural/problems/__init__.py @@ -0,0 +1,3 @@ +from moscot.neural.problems.generic import GENOTLinProblem + +__all__ = ["GENOTLinProblem"] diff --git a/src/moscot/neural/problems/generic/__init__.py b/src/moscot/neural/problems/generic/__init__.py new file mode 100644 index 000000000..657b4ea6c --- /dev/null +++ b/src/moscot/neural/problems/generic/__init__.py @@ -0,0 +1,3 @@ +from moscot.neural.problems.generic._generic import GENOTLinProblem + +__all__ = ["GENOTLinProblem"] diff --git a/src/moscot/neural/problems/generic/_generic.py b/src/moscot/neural/problems/generic/_generic.py new file mode 100644 index 000000000..c37a2f75c --- /dev/null +++ b/src/moscot/neural/problems/generic/_generic.py @@ -0,0 +1,78 @@ +import types +from types import MappingProxyType +from typing import Any, Dict, Literal, Mapping, Optional, Tuple, Type, Union + +from moscot import _constants +from moscot._types import CostKwargs_t, OttCostFn_t, Policy_t +from moscot.neural.base.problems.problem import NeuralOTProblem +from moscot.problems._utils import ( + handle_conditional_attr, + handle_cost_tmp, + handle_joint_attr_tmp, +) + +__all__ = ["GENOTLinProblem"] + + +class GENOTLinProblem(NeuralOTProblem): + """Class for solving Conditional Parameterized Monge Map problems / Conditional Neural OT problems.""" + + def prepare( + self, + key: str, + joint_attr: Union[str, Mapping[str, Any]], + conditional_attr: Union[str, Mapping[str, Any]], + policy: Literal["sequential", "star", "explicit"] = "sequential", + a: Optional[str] = None, + b: Optional[str] = None, + cost: OttCostFn_t = "sq_euclidean", + cost_kwargs: CostKwargs_t = types.MappingProxyType({}), + **kwargs: Any, + ) -> "GENOTLinProblem": + """Prepare the :class:`moscot.problems.generic.GENOTLinProblem`.""" + self.batch_key = key + xy, kwargs = handle_joint_attr_tmp(joint_attr, kwargs) + conditions = handle_conditional_attr(conditional_attr) + xy, xx = handle_cost_tmp(xy=xy, x={}, y={}, cost=cost, cost_kwargs=cost_kwargs) + return super().prepare( + policy_key=key, + policy=policy, + xy=xy, + xx=xx, + conditions=conditions, + a=a, + b=b, + **kwargs, + ) + + def solve( + self, + batch_size: int = 1024, + seed: int = 0, + iterations: int = 25000, # TODO(@MUCDK): rename to max_iterations + valid_freq: int = 50, + valid_sinkhorn_kwargs: Dict[str, Any] = MappingProxyType({}), + train_size: float = 1.0, + **kwargs: Any, + ) -> 
"GENOTLinProblem": + """Solve.""" + return super().solve( + batch_size=batch_size, + # tau_a=tau_a, # TODO: unbalancedness handler + # tau_b=tau_b, + seed=seed, + n_iters=iterations, + valid_freq=valid_freq, + valid_sinkhorn_kwargs=valid_sinkhorn_kwargs, + train_size=train_size, + solver_name="GENOTLinSolver", + **kwargs, + ) + + @property + def _base_problem_type(self) -> Type[NeuralOTProblem]: + return NeuralOTProblem + + @property + def _valid_policies(self) -> Tuple[Policy_t, ...]: + return _constants.SEQUENTIAL, _constants.EXPLICIT # type: ignore[return-value] diff --git a/src/moscot/problems/__init__.py b/src/moscot/problems/__init__.py index 14f4422f5..b96b993b4 100644 --- a/src/moscot/problems/__init__.py +++ b/src/moscot/problems/__init__.py @@ -1,5 +1,4 @@ from moscot.problems.cross_modality import TranslationProblem -from moscot.problems.generic import GENOTLinProblem from moscot.problems.space import AlignmentProblem, MappingProblem from moscot.problems.spatiotemporal import SpatioTemporalProblem from moscot.problems.time import LineageProblem, TemporalProblem @@ -11,5 +10,4 @@ "SpatioTemporalProblem", "LineageProblem", "TemporalProblem", - "GENOTLinProblem", ] diff --git a/src/moscot/problems/generic/__init__.py b/src/moscot/problems/generic/__init__.py index d96dc4db6..ef9b78951 100644 --- a/src/moscot/problems/generic/__init__.py +++ b/src/moscot/problems/generic/__init__.py @@ -1,14 +1,8 @@ -from moscot.problems.generic._generic import ( - FGWProblem, - GENOTLinProblem, - GWProblem, - SinkhornProblem, -) +from moscot.problems.generic._generic import FGWProblem, GWProblem, SinkhornProblem from moscot.problems.generic._mixins import GenericAnalysisMixin __all__ = [ "FGWProblem" "SinkhornProblem", - "GENOTLinProblem", "GWProblem", "GenericAnalysisMixin", ] diff --git a/src/moscot/problems/generic/_generic.py b/src/moscot/problems/generic/_generic.py index fcd3d8b2e..62dfbcf68 100644 --- a/src/moscot/problems/generic/_generic.py +++ b/src/moscot/problems/generic/_generic.py @@ -1,5 +1,4 @@ import types -from types import MappingProxyType from typing import Any, Dict, Literal, Mapping, Optional, Sequence, Tuple, Type, Union from anndata import AnnData @@ -16,17 +15,11 @@ SinkhornInitializer_t, ) from moscot.base.problems.compound_problem import B, Callback_t, CompoundProblem, K -from moscot.base.problems.problem import CondOTProblem, OTProblem -from moscot.problems._utils import ( - handle_conditional_attr, - handle_cost, - handle_cost_tmp, - handle_joint_attr, - handle_joint_attr_tmp, -) +from moscot.base.problems.problem import OTProblem +from moscot.problems._utils import handle_cost, handle_joint_attr from moscot.problems.generic._mixins import GenericAnalysisMixin -__all__ = ["SinkhornProblem", "GWProblem", "GENOTLinProblem", "FGWProblem"] +__all__ = ["SinkhornProblem", "GWProblem", "FGWProblem"] def set_quad_defaults(z: Optional[Union[str, Mapping[str, Any]]]) -> Dict[str, str]: @@ -37,7 +30,7 @@ def set_quad_defaults(z: Optional[Union[str, Mapping[str, Any]]]) -> Dict[str, s raise TypeError("`x_attr` and `y_attr` must be of type `str` or `dict` if no callback is provided.") -class SinkhornProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): # type: ignore[misc] +class SinkhornProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): """Class for solving a :term:`linear problem`. 
Parameters @@ -264,7 +257,7 @@ def _valid_policies(self) -> Tuple[Policy_t, ...]: return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR # type: ignore[return-value] -class GWProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): # type: ignore[misc] +class GWProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): """Class for solving the :term:`GW ` or :term:`FGW ` problems. Parameters @@ -774,67 +767,3 @@ def _base_problem_type(self) -> Type[B]: @property def _valid_policies(self) -> Tuple[Policy_t, ...]: return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR # type: ignore[return-value] - - -class GENOTLinProblem(CondOTProblem): - """Class for solving Conditional Parameterized Monge Map problems / Conditional Neural OT problems.""" - - def prepare( - self, - key: str, - joint_attr: Union[str, Mapping[str, Any]], - conditional_attr: Union[str, Mapping[str, Any]], - policy: Literal["sequential", "star", "explicit"] = "sequential", - a: Optional[str] = None, - b: Optional[str] = None, - cost: OttCostFn_t = "sq_euclidean", - cost_kwargs: CostKwargs_t = types.MappingProxyType({}), - **kwargs: Any, - ) -> "GENOTLinProblem": - """Prepare the :class:`moscot.problems.generic.GENOTLinProblem`.""" - self.batch_key = key - xy, kwargs = handle_joint_attr_tmp(joint_attr, kwargs) - conditions = handle_conditional_attr(conditional_attr) - xy, xx = handle_cost_tmp(xy=xy, x={}, y={}, cost=cost, cost_kwargs=cost_kwargs) - return super().prepare( - policy_key=key, - policy=policy, - xy=xy, - xx=xx, - conditions=conditions, - a=a, - b=b, - **kwargs, - ) - - def solve( - self, - batch_size: int = 1024, - seed: int = 0, - iterations: int = 25000, # TODO(@MUCDK): rename to max_iterations - valid_freq: int = 50, - valid_sinkhorn_kwargs: Dict[str, Any] = MappingProxyType({}), - train_size: float = 1.0, - **kwargs: Any, - ) -> "GENOTLinProblem": - """Solve.""" - return super().solve( - batch_size=batch_size, - # tau_a=tau_a, # TODO: unbalancedness handler - # tau_b=tau_b, - seed=seed, - n_iters=iterations, - valid_freq=valid_freq, - valid_sinkhorn_kwargs=valid_sinkhorn_kwargs, - train_size=train_size, - solver_name="GENOTLinSolver", - **kwargs, - ) - - @property - def _base_problem_type(self) -> Type[CondOTProblem]: - return CondOTProblem - - @property - def _valid_policies(self) -> Tuple[Policy_t, ...]: - return _constants.SEQUENTIAL, _constants.EXPLICIT # type: ignore[return-value] diff --git a/src/moscot/problems/space/_mapping.py b/src/moscot/problems/space/_mapping.py index 2994de5cd..6fd5481d1 100644 --- a/src/moscot/problems/space/_mapping.py +++ b/src/moscot/problems/space/_mapping.py @@ -65,8 +65,8 @@ def _create_problem( adata_tgt=self.adata_sc, src_obs_mask=src_mask, tgt_obs_mask=None, - src_var_mask=self.filtered_vars, # type: ignore[arg-type] - tgt_var_mask=self.filtered_vars, # type: ignore[arg-type] + src_var_mask=self.filtered_vars, + tgt_var_mask=self.filtered_vars, src_key=src, tgt_key=tgt, **kwargs, diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 6181de0cf..1bfc652e7 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -721,7 +721,7 @@ def spatial_key(self, key: Optional[str]) -> None: def _compute_correspondence( spatial: ArrayLike, - features: ArrayLike, + features: Union[ArrayLike, sp.spmatrix, sp.sparray], interval: Union[int, ArrayLike] = 10, max_dist: Optional[int] = None, ) -> pd.DataFrame: @@ -743,7 +743,8 @@ def pdist(row_idx: ArrayLike, col_idx: float, feat: 
ArrayLike) -> Any: # TODO(michalk8): vectorize using jax, this is just a for loop vpdist = np.vectorize(pdist, excluded=["feat"]) if sp.issparse(features): - features = features.toarray() # type: ignore[attr-defined] + features_sp: Union[sp.spmatrix, sp.sparray] = features + features = features_sp.toarray() feat_arr, index_arr, support_arr = [], [], [] for ind, i in enumerate(support): diff --git a/src/moscot/problems/time/_lineage.py b/src/moscot/problems/time/_lineage.py index 3fcfd53c9..fe95950dc 100644 --- a/src/moscot/problems/time/_lineage.py +++ b/src/moscot/problems/time/_lineage.py @@ -23,7 +23,7 @@ __all__ = ["TemporalProblem", "LineageProblem"] -class TemporalProblem( # type: ignore[misc] +class TemporalProblem( TemporalMixin[Numeric_t, BirthDeathProblem], BirthDeathMixin, CompoundProblem[Numeric_t, BirthDeathProblem] ): """Class for analyzing time-series single cell data based on :cite:`schiebinger:19`. diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index dee020f35..bc650345d 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -545,7 +545,7 @@ def _get_data( intermediate_data.astype(np.float64) if intermediate_data is not None else None, intermediate_adata, target_data.astype(np.float64) if target_data is not None else None, - ) # type: ignore[return-value] + ) def compute_interpolated_distance( self, diff --git a/tests/problems/generic/test_conditional_neural_problem.py b/tests/neural/problems/generic/test_conditional_neural_problem.py similarity index 92% rename from tests/problems/generic/test_conditional_neural_problem.py rename to tests/neural/problems/generic/test_conditional_neural_problem.py index 5a7297de0..e4cd9b832 100644 --- a/tests/problems/generic/test_conditional_neural_problem.py +++ b/tests/neural/problems/generic/test_conditional_neural_problem.py @@ -6,9 +6,9 @@ import anndata as ad -from moscot.base.output import BaseDiscreteSolverOutput -from moscot.base.problems import CondOTProblem -from moscot.problems.generic import GENOTLinProblem # type: ignore[attr-defined] +from moscot.base.output import BaseSolverOutput +from moscot.neural.base.problems import NeuralOTProblem +from moscot.neural.problems.generic import GENOTLinProblem # type: ignore[attr-defined] from moscot.utils.tagged_array import DistributionCollection, DistributionContainer from tests._utils import ATOL, RTOL from tests.problems.conftest import neurallin_cond_args_1 @@ -19,7 +19,7 @@ class TestGENOTLinProblem: def test_prepare(self, adata_time: ad.AnnData): problem = GENOTLinProblem(adata=adata_time) problem = problem.prepare(key="time", joint_attr="X_pca", conditional_attr={"attr": "obs", "key": "time"}) - assert isinstance(problem, CondOTProblem) + assert isinstance(problem, NeuralOTProblem) assert isinstance(problem.distributions, DistributionCollection) assert list(problem.distributions.keys()) == [0, 1, 2] @@ -43,7 +43,7 @@ def test_solve_balanced_no_baseline(self, adata_time: ad.AnnData, train_size: fl problem = GENOTLinProblem(adata=adata_time) problem = problem.prepare(key="time", joint_attr="X_pca", conditional_attr={"attr": "obs", "key": "time"}) problem = problem.solve(train_size=train_size, **neurallin_cond_args_1) - assert isinstance(problem.solution, BaseDiscreteSolverOutput) + assert isinstance(problem.solution, BaseSolverOutput) def test_reproducibility(self, adata_time: ad.AnnData): cond_zero_mask = np.array(adata_time.obs["time"] == 0)
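Usage sketch (not part of the patch): how the relocated conditional problem and the push-only neural output fit together after this change. It mirrors the updated test in tests/neural/problems/generic/test_conditional_neural_problem.py; the synthetic AnnData object, the solve() keyword values, and the condition array below are illustrative placeholders, not values taken from the patch, and the toy settings are not expected to produce a converged model.

# Sketch only, assuming a small synthetic AnnData; GENOT itself requires jax/ott-jax.
import numpy as np
import anndata as ad

from moscot.neural.base.problems import NeuralOTProblem    # replaces moscot.base.problems.CondOTProblem
from moscot.neural.problems.generic import GENOTLinProblem  # moved out of moscot.problems.generic

rng = np.random.default_rng(0)
adata = ad.AnnData(X=rng.normal(size=(60, 20)).astype(np.float32))
adata.obs["time"] = np.repeat([0, 1, 2], 20)           # policy / condition key (placeholder data)
adata.obsm["X_pca"] = rng.normal(size=(60, 5)).astype(np.float32)

problem = GENOTLinProblem(adata=adata)
problem = problem.prepare(
    key="time",
    joint_attr="X_pca",                                # linear term
    conditional_attr={"attr": "obs", "key": "time"},   # condition of the conditional OT problem
)
assert isinstance(problem, NeuralOTProblem)

problem = problem.solve(batch_size=8, iterations=10, train_size=1.0, seed=0)  # toy values

out = problem.solution                                 # moscot.backends.ott.output.NeuralOutput
x = adata.obsm["X_pca"][np.asarray(adata.obs["time"] == 0)]
cond = np.zeros((x.shape[0], 1), dtype=np.float32)     # hypothetical condition array
pushed = out.push(x, cond)                             # forward transport only; pull() was removed

Compared to the removed OTTNeuralOutput, NeuralOutput exposes no pull() and project_to_transport_matrix() no longer takes a forward= flag; only the forward (push) direction is supported.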