Improve GPU handling in linop.xray.astra (#505)
* Add utility script

* Fix astra installation instructions

* Add astra utility function

* Fix typing errors (probably introduced in a copy from the admm module)

* Minor comment fix

* Docs fixes/improvements

* Remove error when the jax device is cpu: a gpu can be available to astra but not to jax

* Improve cpu/gpu selection

* Fix tests
bwohlberg authored Feb 20, 2024
1 parent 6e46ca2 commit 19f8668
Showing 8 changed files with 83 additions and 39 deletions.
2 changes: 1 addition & 1 deletion docs/source/include/examplenotes.rst
@@ -9,7 +9,7 @@ which should be installed via conda:

::

-    conda install -c astra-toolbox astra-toolbox
+    conda install astra-toolbox
pip install -r examples/examples_requirements.txt # Installs other example requirements

The dependencies can also be installed individually as required.
32 changes: 22 additions & 10 deletions docs/source/notes.rst
@@ -2,15 +2,6 @@
Notes
*****

-No GPU/TPU Warning
-==================
-
-JAX currently issues a warning when used on a platform without a
-GPU. To disable this warning, set the environment variable
-``JAX_PLATFORM_NAME=cpu`` before running Python. This warning is
-suppressed by SCICO for JAX versions after 0.3.23, making use of
-the environment variable unnecessary.


Debugging
=========
@@ -44,6 +35,26 @@ can be enabled in one of two ways:

For more information, see the `JAX notes on double precision <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision>`_.
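
As a minimal illustration (an editor's sketch, not part of this diff), the two
standard JAX mechanisms are the JAX_ENABLE_X64 environment variable and the
config API; either must take effect before any arrays are created:

    import os

    os.environ["JAX_ENABLE_X64"] = "True"  # one way: environment variable

    from jax import config

    config.update("jax_enable_x64", True)  # the other: the config API

    import jax.numpy as jnp

    print(jnp.ones(3).dtype)  # float64 rather than the default float32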

+Device Control
+==============
+
+Use of the CPU device can be forced even when GPUs are present by setting the
+environment variable ``JAX_PLATFORM_NAME=cpu`` before running Python. This also
+serves to disable the warning that older versions of JAX issued when running
+on a platform without a GPU, but it should no longer be necessary for any
+JAX version supported by SCICO.
+
+By default, JAX views a multi-core CPU as a single device. Primarily for testing
+purposes, it may be useful to instruct JAX to emulate multiple CPU devices by
+setting the environment variable ``XLA_FLAGS='--xla_force_host_platform_device_count=<n>'``,
+where ``<n>`` is an integer number of devices. For more detail, see the relevant
+`section of the JAX docs <https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html#aside-hosts-and-devices-in-jax>`__.
+
+By default, JAX preallocates a large chunk of GPU memory on startup. This
+behavior can be controlled via the environment variables ``XLA_PYTHON_CLIENT_PREALLOCATE``,
+``XLA_PYTHON_CLIENT_MEM_FRACTION``, and ``XLA_PYTHON_CLIENT_ALLOCATOR``, as described in
+the relevant `section of the JAX docs <https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html>`__.
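
A combined sketch of the three controls described above (an editor's addition,
not part of this diff; the specific values are illustrative). All of these
environment variables must be set before jax is first imported:

    import os

    os.environ["JAX_PLATFORM_NAME"] = "cpu"  # force CPU even when GPUs are present
    os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"  # emulate 4 CPU devices
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate GPU memory on demand

    import jax

    print(jax.devices())  # with the flags above, four CPU devices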


Random Number Generation
========================
@@ -129,7 +140,8 @@ SCICO, while the other two depend on external packages.
The :class:`.xray.svmbir.XRayTransform` class is implemented
via an interface to the `svmbir
<https://svmbir.readthedocs.io/en/latest/>`__ package. The
-:class:`.xray.astra.XRayTransform` class is implemented via an
+:class:`.xray.astra.XRayTransform2D` and
+:class:`.xray.astra.XRayTransform3D` classes are implemented via an
interface to the `ASTRA toolbox
<https://www.astra-toolbox.com/>`__. This toolbox does provide some
GPU acceleration support, but efficiency is expected to be lower than
24 changes: 24 additions & 0 deletions misc/availgpu.py
@@ -0,0 +1,24 @@
#!/usr/bin/env python

# Determine which GPUs are available for use and recommend a CUDA_VISIBLE_DEVICES
# setting if any are already in use.

# pylint: disable=missing-module-docstring


import GPUtil

print("GPU utilization")
GPUtil.showUtilization()

devIDs = GPUtil.getAvailable(
    order="first", limit=65536, maxLoad=0.1, maxMemory=0.1, includeNan=False
)

Ngpu = len(GPUtil.getGPUs())
if len(devIDs) == Ngpu:
    print(f"All {Ngpu} GPUs available for use")
else:
    print(f"Only {len(devIDs)} of {Ngpu} GPUs available for use")
    print("To avoid attempting to use GPUs already in use, run the command")
    print(f"    export CUDA_VISIBLE_DEVICES={','.join(map(str, devIDs))}")
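
The script only prints a recommendation. A hedged sketch of applying the same
GPUtil-based selection programmatically (the thresholds mirror the script, and
this must run before jax is first imported):

    import os

    import GPUtil

    devIDs = GPUtil.getAvailable(order="first", limit=65536, maxLoad=0.1, maxMemory=0.1)
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, devIDs))

    import jax  # jax (and astra) now see only the GPUs selected above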
2 changes: 1 addition & 1 deletion scico/flax/examples/data_generation.py
@@ -221,7 +221,7 @@ def generate_ct_data(
    # Normalize sinogram
    sino = sino / size

-    # Compute filter back-project in parallel.
+    # Compute filtered back projection in parallel.
    afbp_map = lambda v: jnp.atleast_3d(A.fbp(v.squeeze()))
    start_time = time()
    fbpshd = jax.pmap(lambda i: jax.lax.map(afbp_map, sinoshd[i]))(jnp.arange(nproc))
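
For context, a self-contained sketch of a close variant of the jax.pmap /
jax.lax.map pattern used here, with a stand-in for A.fbp (it assumes the
leading axis of the sharded array equals the local device count):

    import jax
    import jax.numpy as jnp

    nproc = jax.local_device_count()
    sino = jnp.ones((nproc, 8, 32, 32))  # (device, batch, rows, cols)

    fbp = lambda v: v * 2.0  # stand-in for a per-sinogram reconstruction
    recon = jax.pmap(lambda batch: jax.lax.map(fbp, batch))(sino)
    print(recon.shape)  # (nproc, 8, 32, 32)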
37 changes: 27 additions & 10 deletions scico/linop/xray/astra.py
@@ -17,7 +17,7 @@
not available.
"""

-from typing import List, Optional, Tuple
+from typing import List, Optional, Sequence, Tuple, Union

import numpy as np

@@ -33,12 +33,28 @@
else:
    raise e

+try:
+    from collections import Iterable  # type: ignore
+except ImportError:
+    import collections
+
+    # Monkey patching required because latest astra release uses old module path for Iterable
+    collections.Iterable = collections.abc.Iterable  # type: ignore

from scico.typing import Shape

from .._linop import LinearOperator


+def set_astra_gpu_index(idx: Union[int, Sequence[int]]):
+    """Set the index/indices of GPU(s) to be used by astra.
+
+    Args:
+        idx: Index or indices of GPU(s).
+    """
+    astra.set_gpu_index(idx)
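
A short usage example for the new helper (an editor's sketch; GPU indices are
interpreted by astra with respect to the devices visible to CUDA):

    from scico.linop.xray.astra import set_astra_gpu_index

    set_astra_gpu_index(0)       # restrict astra to a single GPU
    set_astra_gpu_index([0, 1])  # or to several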


class XRayTransform2D(LinearOperator):
r"""2D parallel beam X-ray transform based on the ASTRA toolbox.
@@ -108,11 +124,16 @@ def __init__(
"Please see the astra documentation for details."
)

-        dev0 = jax.devices()[0]
-        if dev0.platform == "cpu" or device == "cpu":
-            self.device = "cpu"
-        elif dev0.platform == "gpu" and device in ["gpu", "auto"]:
-            self.device = "gpu"
+        if device in ["cpu", "gpu"]:
+            # If cpu or gpu selected, attempt to comply (no checking to
+            # confirm that a gpu is available to astra).
+            self.device = device
+        elif device == "auto":
+            # If auto selected, use cpu or gpu depending on the default
+            # jax device (for simplicity, no checking whether gpu is
+            # available to astra when one is not available to jax).
+            dev0 = jax.devices()[0]
+            self.device = dev0.platform
        else:
            raise ValueError(f"Invalid device specified; got {device}.")
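
A hedged construction example (editor's sketch; the geometry arguments are
illustrative, patterned on the 3D test further below). An explicit "gpu"
request is passed through to astra without verifying that a GPU is actually
available to it:

    import numpy as np

    from scico.linop.xray.astra import XRayTransform2D

    A = XRayTransform2D(
        input_shape=(256, 256),
        det_count=512,
        det_spacing=1.0,
        angles=np.linspace(0, np.pi, 180),
        device="auto",  # or "cpu" / "gpu"
    )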

Expand Down Expand Up @@ -280,10 +301,6 @@ def __init__(
        self.input_shape: tuple = input_shape
        self.vol_geom = astra.create_vol_geom(input_shape[1], input_shape[2], input_shape[0])

-        dev0 = jax.devices()[0]
-        if dev0.platform == "cpu":
-            raise ValueError("No CPU algorithm for 3D projection and GPU not available.")

        # Wrap our non-jax function to indicate we will supply fwd/rev mode functions
        self._eval = jax.custom_vjp(self._proj)
        self._eval.defvjp(lambda x: (self._proj(x), None), lambda _, y: (self._bproj(y),))  # type: ignore
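
A standalone sketch of this custom_vjp pattern (editor's addition; assumes a
JAX version providing jax.pure_callback). For a linear operator the backward
pass is simply the adjoint applied to the cotangent, so the forward function
saves no residuals:

    import jax
    import jax.numpy as jnp
    import numpy as np

    M = np.random.default_rng(0).normal(size=(3, 4)).astype(np.float32)

    # Host-side (non-jax) forward and adjoint computations wrapped as callbacks.
    def proj(x):
        return jax.pure_callback(
            lambda v: (M @ np.asarray(v)).astype(np.float32),
            jax.ShapeDtypeStruct((3,), jnp.float32),
            x,
        )

    def bproj(y):
        return jax.pure_callback(
            lambda v: (M.T @ np.asarray(v)).astype(np.float32),
            jax.ShapeDtypeStruct((4,), jnp.float32),
            y,
        )

    # Same wiring as in the diff: forward saves no residuals; backward applies the adjoint.
    f = jax.custom_vjp(proj)
    f.defvjp(lambda x: (proj(x), None), lambda _, y: (bproj(y),))

    x = jnp.ones(4, dtype=jnp.float32)
    y, vjp = jax.vjp(f, x)
    print(vjp(y)[0])  # equals M.T @ (M @ x)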
6 changes: 3 additions & 3 deletions scico/optimize/_ladmm.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright (C) 2021-2023 by SCICO Developers
+# Copyright (C) 2021-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
@@ -11,7 +11,7 @@
# see https://www.python.org/dev/peps/pep-0563/
from __future__ import annotations

-from typing import List, Optional, Tuple, Union
+from typing import Optional, Tuple, Union

import scico.numpy as snp
from scico.functional import Functional
@@ -149,7 +149,7 @@ def minimizer(self):
    def objective(
        self,
        x: Optional[Union[Array, BlockArray]] = None,
-        z: Optional[List[Union[Array, BlockArray]]] = None,
+        z: Optional[Union[Array, BlockArray]] = None,
    ) -> float:
        r"""Evaluate the objective function.
10 changes: 5 additions & 5 deletions scico/optimize/_padmm.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright (C) 2022-2023 by SCICO Developers
+# Copyright (C) 2022-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
@@ -11,7 +11,7 @@
# see https://www.python.org/dev/peps/pep-0563/
from __future__ import annotations

-from typing import List, Optional, Tuple, Union
+from typing import Optional, Tuple, Union

import scico.numpy as snp
from scico import cvjp, jvp
@@ -144,7 +144,7 @@ def minimizer(self):
    def objective(
        self,
        x: Optional[Union[Array, BlockArray]] = None,
-        z: Optional[List[Union[Array, BlockArray]]] = None,
+        z: Optional[Union[Array, BlockArray]] = None,
    ) -> float:
        r"""Evaluate the objective function.
@@ -289,7 +289,7 @@ def __init__(
    def norm_primal_residual(
        self,
        x: Optional[Union[Array, BlockArray]] = None,
-        z: Optional[List[Union[Array, BlockArray]]] = None,
+        z: Optional[Union[Array, BlockArray]] = None,
    ) -> float:
        r"""Compute the :math:`\ell_2` norm of the primal residual.
@@ -507,7 +507,7 @@ def __init__(
    def norm_primal_residual(
        self,
        x: Optional[Union[Array, BlockArray]] = None,
-        z: Optional[List[Union[Array, BlockArray]]] = None,
+        z: Optional[Union[Array, BlockArray]] = None,
    ) -> float:
        r"""Compute the :math:`\ell_2` norm of the primal residual.
9 changes: 0 additions & 9 deletions scico/test/linop/xray/test_astra.py
@@ -129,15 +129,6 @@ def test_jit_in_DiagonalStack():
    H.T @ snp.zeros(H.output_shape, dtype=snp.float32)


-@pytest.mark.skipif(jax.devices()[0].platform != "cpu", reason="checking CPU behavior")
-def test_3D_on_CPU():
-    x = snp.zeros((4, 5, 6))
-    with pytest.raises(ValueError):
-        A = XRayTransform3D(
-            x.shape, det_count=[6, 6], det_spacing=[1.0, 1.0], angles=snp.linspace(0, snp.pi, 10)
-        )


@pytest.mark.skipif(jax.devices()[0].platform != "gpu", reason="checking GPU behavior")
def test_3D_on_GPU():
    x = snp.zeros((4, 5, 6))
