
Add diagonal operator mapping base operator over an array axis #521

Merged · 13 commits · Jun 25, 2024
3 changes: 2 additions & 1 deletion scico/linop/__init__.py
@@ -17,7 +17,7 @@
from ._func import Crop, Pad, Reshape, Slice, Sum, Transpose, linop_from_function
from ._linop import ComposedLinearOperator, LinearOperator
from ._matrix import MatrixOperator
from ._stack import DiagonalStack, VerticalStack, linop_over_axes
from ._stack import DiagonalReplicated, DiagonalStack, VerticalStack, linop_over_axes
from ._util import jacobian, operator_norm, power_iteration, valid_adjoint
from .xray import Parallel2dProjector, XRayTransform

@@ -29,6 +29,7 @@
"FiniteDifference",
"SingleAxisFiniteDifference",
"Identity",
"DiagonalReplicated",
"VerticalStack",
"DiagonalStack",
"MatrixOperator",
78 changes: 78 additions & 0 deletions scico/linop/_stack.py
@@ -14,6 +14,7 @@
import scico.numpy as snp
from scico.numpy import Array, BlockArray
from scico.numpy.util import normalize_axes
from scico.operator._stack import DiagonalReplicated as DReplicated
from scico.operator._stack import DiagonalStack as DStack
from scico.operator._stack import VerticalStack as VStack
from scico.typing import Axes, Shape
@@ -146,6 +147,83 @@ def _adj(self, y: Union[Array, BlockArray]) -> Union[Array, BlockArray]: # type
return snp.blockarray(result)


class DiagonalReplicated(DReplicated, LinearOperator):
Contributor:
It takes a second to figure out what DReplicated is here. Consider DiagonalReplicatedOperator?

Collaborator (author):
I assume you noticed that this is an alias for operator.DiagonalReplicated. You don't have similar concerns for DStack and VStack? If DReplicated is changed as suggested, it would seem to make sense to modify them too for consistency.

Collaborator (author):
Addressed as discussed.
r"""A diagonal stack constructed from a single linear operator.

Given linear operator :math:`A`, create the linear operator

.. math::
H =
\begin{pmatrix}
A & 0 & \ldots & 0\\
0 & A & \ldots & 0\\
\vdots & \vdots & \ddots & \vdots\\
0 & 0 & \ldots & A \\
\end{pmatrix} \qquad
\text{such that} \qquad
H
\begin{pmatrix}
\mb{x}_1 \\
\mb{x}_2 \\
\vdots \\
\mb{x}_N \\
\end{pmatrix}
=
\begin{pmatrix}
A(\mb{x}_1) \\
A(\mb{x}_2) \\
\vdots \\
A(\mb{x}_N) \\
\end{pmatrix} \;.

The application of :math:`A` to each component :math:`\mb{x}_k` is
computed using :func:`jax.pmap` or :func:`jax.vmap`. The input shape
for linear operator :math:`A` should exclude the array axis on which
:math:`A` is replicated to form :math:`H`. For example, if :math:`A`
has input shape `(3, 4)` and :math:`H` is constructed to replicate
on axis 0 with 2 replicates, the input shape of :math:`H` will be
`(2, 3, 4)`.

Linear operators taking :class:`.BlockArray` input are not supported.
"""

def __init__(
self,
op: LinearOperator,
replicates: int,
input_axis: int = 0,
output_axis: Optional[int] = None,
map_type: str = "auto",
**kwargs,
):
"""
Args:
op: Linear operator to replicate.
replicates: Number of replicates of `op`.
input_axis: Input axis over which `op` should be replicated.
output_axis: Index of replication axis in output array.
If ``None``, the input replication axis is used.
map_type: If "pmap" or "vmap", apply replicated mapping using
:func:`jax.pmap` or :func:`jax.vmap` respectively. If
"auto", use :func:`jax.pmap` if sufficient devices are
available for the number of replicates, otherwise use
:func:`jax.vmap`.
"""
if not isinstance(op, LinearOperator):
raise TypeError("Argument op must be of type LinearOperator.")

super().__init__(
op,
replicates,
input_axis=input_axis,
output_axis=output_axis,
map_type=map_type,
**kwargs,
)

self._adj = self.jaxmap(op.adj, in_axes=self.input_axis, out_axes=self.output_axis)


def linop_over_axes(
linop: type[LinearOperator],
input_shape: Shape,
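An illustrative usage sketch of the new scico.linop.DiagonalReplicated (assuming scico and jax are installed); it mirrors the adjoint test added to scico/test/linop/test_linop_stack.py in this PR:

import numpy as np
import jax

from scico.linop import DiagonalReplicated, Sum
from scico.random import randn

x, _ = randn((2, 3, 4), key=jax.random.PRNGKey(0))
A = Sum(x.shape[1:], axis=-1)                      # base operator acts on a single (3, 4) slice
D = DiagonalReplicated(A, replicates=x.shape[0])   # maps A over axis 0 of the (2, 3, 4) input

y = D(x)      # shape (2, 3): A applied to each slice
z = D.T(y)    # adjoint is mapped over the same axis, shape (2, 3, 4)
np.testing.assert_allclose(z[0], A.T(A(x[0])))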
5 changes: 3 additions & 2 deletions scico/operator/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2021-2023 by SCICO Developers
# Copyright (C) 2021-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
@@ -13,11 +13,12 @@
from ._operator import Operator
from .biconvolve import BiConvolve
from ._func import operator_from_function, Abs, Angle, Exp
from ._stack import DiagonalStack, VerticalStack
from ._stack import DiagonalStack, VerticalStack, DiagonalReplicated

__all__ = [
"Operator",
"BiConvolve",
"DiagonalReplicated",
"DiagonalStack",
"VerticalStack",
"operator_from_function",
108 changes: 107 additions & 1 deletion scico/operator/_stack.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2023 by SCICO Developers
# Copyright (C) 2023-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
@@ -13,6 +13,8 @@

import numpy as np

import jax

from typing_extensions import TypeGuard

import scico.numpy as snp
@@ -234,3 +236,107 @@ def _eval(self, x: Union[Array, BlockArray]) -> Union[Array, BlockArray]:
if self.collapse_output:
return snp.stack(result)
return snp.blockarray(result)


class DiagonalReplicated(Operator):
r"""A diagonal stack constructed from a single operator.

Given operator :math:`A`, create the operator :math:`H` such that

.. math::
H \left(
\begin{pmatrix}
\mb{x}_1 \\
\mb{x}_2 \\
\vdots \\
\mb{x}_N \\
\end{pmatrix} \right)
=
\begin{pmatrix}
A(\mb{x}_1) \\
A(\mb{x}_2) \\
\vdots \\
A(\mb{x}_N) \\
\end{pmatrix} \;.

The application of :math:`A` to each component :math:`\mb{x}_k` is
computed using :func:`jax.pmap` or :func:`jax.vmap`. The input shape
for operator :math:`A` should exclude the array axis on which
:math:`A` is replicated to form :math:`H`. For example, if :math:`A`
has input shape `(3, 4)` and :math:`H` is constructed to replicate
on axis 0 with 2 replicates, the input shape of :math:`H` will be
`(2, 3, 4)`.

Operators taking :class:`.BlockArray` input are not supported.
"""

def __init__(
self,
op: Operator,
replicates: int,
input_axis: int = 0,
output_axis: Optional[int] = None,
map_type: str = "auto",
**kwargs,
):
"""
Args:
op: Operator to replicate.
replicates: Number of replicates of `op`.
input_axis: Input axis over which `op` should be replicated.
output_axis: Index of replication axis in output array.
If ``None``, the input replication axis is used.
map_type: If "pmap" or "vmap", apply replicated mapping using
:func:`jax.pmap` or :func:`jax.vmap` respectively. If
"auto", use :func:`jax.pmap` if sufficient devices are
available for the number of replicates, otherwise use
:func:`jax.vmap`.
"""
if map_type not in ["auto", "pmap", "vmap"]:
raise ValueError("Argument map_type must be one of 'auto', 'pmap, or 'vmap'.")
if input_axis < 0:
input_axis = len(op.input_shape) + 1 + input_axis
if input_axis < 0 or input_axis > len(op.input_shape):
raise ValueError(
"Argument input_axis must be positive and less than the number of axes "
"in the input shape of op."
)
if is_nested(op.input_shape):
raise ValueError("Argument op may not be an Operator taking BlockArray input.")
if is_nested(op.output_shape):
raise ValueError("Argument op may not be an Operator with BlockArray output.")
self.op = op
self.replicates = replicates
self.input_axis = input_axis
self.output_axis = self.input_axis if output_axis is None else output_axis

if map_type == "auto":
self.jaxmap = jax.pmap if replicates <= jax.device_count() else jax.vmap
else:
if map_type == "pmap" and replicates > jax.device_count():
raise ValueError(
"Requested pmap mapping but number of replicates exceeds device count."
)
else:
self.jaxmap = jax.pmap if map_type == "pmap" else jax.vmap

eval_fn = self.jaxmap(op.__call__, in_axes=self.input_axis, out_axes=self.output_axis)

input_shape = (
op.input_shape[0 : self.input_axis] + (replicates,) + op.input_shape[self.input_axis :]
)
output_shape = (
op.output_shape[0 : self.output_axis]
+ (replicates,)
+ op.output_shape[self.output_axis :]
)

super().__init__(
input_shape=input_shape, # type: ignore
output_shape=output_shape, # type: ignore
eval_fn=eval_fn,
input_dtype=op.input_dtype,
output_dtype=op.output_dtype,
jit=False,
**kwargs,
)
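A small sketch of the shape bookkeeping and automatic map selection implemented above, assuming a typical single-device jax installation (with two or more devices, "auto" would select jax.pmap instead):

import jax
from scico.operator import Abs, DiagonalReplicated

A = Abs((3, 4))                          # base operator on a (3, 4) array
D = DiagonalReplicated(A, replicates=2)  # map_type="auto" by default

print(D.input_shape)         # (2, 3, 4): replication axis inserted at input_axis=0
print(D.output_shape)        # (2, 3, 4)
print(D.jaxmap is jax.vmap)  # True on a single device, since 2 replicates exceed 1 device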
23 changes: 22 additions & 1 deletion scico/test/linop/test_linop_stack.py
@@ -5,8 +5,16 @@
import pytest

import scico.numpy as snp
from scico.linop import Convolve, DiagonalStack, Identity, Sum, VerticalStack
from scico.linop import (
Convolve,
DiagonalReplicated,
DiagonalStack,
Identity,
Sum,
VerticalStack,
)
from scico.operator import Abs
from scico.random import randn
from scico.test.linop.test_linop import adjoint_test


@@ -166,3 +174,16 @@ def test_output_collapse(self):

H = DiagonalStack((A1, A2), collapse_output=False)
assert H.output_shape == (S1, S1)


class TestDiagonalReplicated:
def setup_method(self, method):
self.key = jax.random.PRNGKey(12345)

def test_adjoint(self):
x, key = randn((2, 3, 4), key=self.key)
A = Sum(x.shape[1:], axis=-1)
D = DiagonalReplicated(A, x.shape[0])
y = D.T(D(x))
np.testing.assert_allclose(y[0], A.T(A(x[0])))
np.testing.assert_allclose(y[1], A.T(A(x[1])))
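As an alternative to the per-slice comparison above, the adjoint could plausibly be checked with scico.linop.valid_adjoint, which draws random vectors and compares inner products; this sketch assumes its default behavior and an eps tolerance argument:

from scico.linop import DiagonalReplicated, Sum, valid_adjoint

A = Sum((3, 4), axis=-1)
D = DiagonalReplicated(A, 2)
assert valid_adjoint(D, D.T, eps=1e-5)  # verifies <D x, y> ≈ <x, D.T y> numerically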
54 changes: 53 additions & 1 deletion scico/test/operator/test_op_stack.py
@@ -5,7 +5,14 @@
import pytest

import scico.numpy as snp
from scico.operator import Abs, DiagonalStack, Operator, VerticalStack
from scico.operator import (
Abs,
DiagonalReplicated,
DiagonalStack,
Operator,
VerticalStack,
)
from scico.random import randn

TestOpA = Operator(input_shape=(3, 4), output_shape=(2, 3, 4), eval_fn=lambda x: snp.stack((x, x)))
TestOpB = Operator(
@@ -140,3 +147,48 @@ def test_output_collapse(self):

H = DiagonalStack((A1, A2), collapse_output=False)
assert H.output_shape == (A1.output_shape, A1.output_shape)


class TestDiagonalReplicated:
def setup_method(self, method):
self.key = jax.random.PRNGKey(12345)

@pytest.mark.parametrize("map_type", ["auto", "vmap"])
@pytest.mark.parametrize("input_axis", [0, 1])
def test_map_auto_vmap(self, input_axis, map_type):
x, key = randn((2, 3, 4), key=self.key)
mapshape = (3, 4) if input_axis == 0 else (2, 4)
replicates = x.shape[input_axis]
A = Abs(mapshape)
D = DiagonalReplicated(A, replicates, input_axis=input_axis, map_type=map_type)
y = D(x)
assert y.shape[input_axis] == replicates

@pytest.mark.skipif(jax.device_count() < 2, reason="multiple devices required for test")
def test_map_auto_pmap(self):
x, key = randn((2, 3, 4), key=self.key)
A = Abs(x.shape[1:])
replicates = x.shape[0]
D = DiagonalReplicated(A, replicates, map_type="pmap")
y = D(x)
assert y.shape[0] == replicates

def test_input_axis(self):
# Ensure that operators can be stacked on the final axis
x, key = randn((2, 3, 4), key=self.key)
A = Abs(x.shape[0:2])
replicates = x.shape[2]
D = DiagonalReplicated(A, replicates, input_axis=2)
y = D(x)
assert y.shape == (2, 3, 4)
D = DiagonalReplicated(A, replicates, input_axis=-1)
y = D(x)
assert y.shape == (2, 3, 4)

def test_output_axis(self):
x, key = randn((2, 3, 4), key=self.key)
A = Abs(x.shape[1:])
replicates = x.shape[0]
D = DiagonalReplicated(A, replicates, output_axis=1)
y = D(x)
assert y.shape == (3, 2, 4)
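A brief sketch, under the same assumptions as the tests above, of how the input_axis and output_axis arguments behave outside the test harness:

import jax
from scico.operator import Abs, DiagonalReplicated
from scico.random import randn

x, _ = randn((2, 3, 4), key=jax.random.PRNGKey(0))

# Replicate over the final input axis: the base operator sees (2, 3) slices.
D_last = DiagonalReplicated(Abs(x.shape[0:2]), x.shape[2], input_axis=-1)
assert D_last(x).shape == (2, 3, 4)

# Replicate over input axis 0, but place the replication axis at position 1 of the output.
D_move = DiagonalReplicated(Abs(x.shape[1:]), x.shape[0], output_axis=1)
assert D_move(x).shape == (3, 2, 4)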