diff --git a/docs/source/include/examplenotes.rst b/docs/source/include/examplenotes.rst
index 0781478f0..ba067e857 100644
--- a/docs/source/include/examplenotes.rst
+++ b/docs/source/include/examplenotes.rst
@@ -9,7 +9,7 @@ which should be installed via conda:
 
 ::
 
-   conda install -c astra-toolbox astra-toolbox
+   conda install astra-toolbox
    pip install -r examples/examples_requirements.txt  # Installs other example requirements
 
 The dependencies can also be installed individually as required.
diff --git a/docs/source/notes.rst b/docs/source/notes.rst
index 2986f5adc..6fdfa5c32 100644
--- a/docs/source/notes.rst
+++ b/docs/source/notes.rst
@@ -2,15 +2,6 @@
 Notes
 *****
 
-No GPU/TPU Warning
-==================
-
-JAX currently issues a warning when used on a platform without a
-GPU. To disable this warning, set the environment variable
-``JAX_PLATFORM_NAME=cpu`` before running Python. This warning is
-suppressed by SCICO for JAX versions after 0.3.23, making use of
-the environment variable unnecessary.
-
 Debugging
 =========
 
@@ -44,6 +35,26 @@ can be enabled in one of two ways:
 
 For more information, see the `JAX notes on double precision `_.
 
+Device Control
+==============
+
+Use of the CPU device can be forced, even when GPUs are present, by setting the
+environment variable ``JAX_PLATFORM_NAME=cpu`` before running Python. This also
+serves to disable the warning that older versions of JAX issued when running on
+a platform without a GPU, but it should no longer be necessary for any JAX
+version supported by SCICO.
+
+By default, JAX views a multi-core CPU as a single device. Primarily for testing
+purposes, it may be useful to instruct JAX to emulate multiple CPU devices by
+setting the environment variable ``XLA_FLAGS='--xla_force_host_platform_device_count=N'``,
+where ``N`` is an integer number of devices. For more detail see the relevant
+`section of the JAX docs `__.
+
+By default, JAX preallocates a large chunk of GPU memory on startup. This
+behavior can be controlled via the environment variables ``XLA_PYTHON_CLIENT_PREALLOCATE``,
+``XLA_PYTHON_CLIENT_MEM_FRACTION``, and ``XLA_PYTHON_CLIENT_ALLOCATOR``, as described in
+the relevant `section of the JAX docs `__.
+
 Random Number Generation
 ========================
 
@@ -129,7 +140,8 @@ SCICO, while the other two depend on external packages.
 
 The :class:`.xray.svmbir.XRayTransform` class is implemented via an
 interface to the `svmbir `__ package. The
-:class:`.xray.astra.XRayTransform` class is implemented via an
+:class:`.xray.astra.XRayTransform2D` and
+:class:`.xray.astra.XRayTransform3D` classes are implemented via an
 interface to the `ASTRA toolbox
 `__. This toolbox does provide some GPU acceleration support, but
 efficiency is expected to be lower than
diff --git a/misc/availgpu.py b/misc/availgpu.py
new file mode 100644
index 000000000..7ff4f2e16
--- /dev/null
+++ b/misc/availgpu.py
@@ -0,0 +1,24 @@
+#!/usr/bin/env python
+
+# Determine which GPUs are available for use and recommend a
+# CUDA_VISIBLE_DEVICES setting if any are already in use.
+
+# pylint: disable=missing-module-docstring
+
+
+import GPUtil
+
+print("GPU utilization")
+GPUtil.showUtilization()
+
+devIDs = GPUtil.getAvailable(
+    order="first", limit=65536, maxLoad=0.1, maxMemory=0.1, includeNan=False
+)
+
+Ngpu = len(GPUtil.getGPUs())
+if len(devIDs) == Ngpu:
+    print(f"All {Ngpu} GPUs available for use")
+else:
+    print(f"Only {len(devIDs)} of {Ngpu} GPUs available for use")
+    print("To avoid attempting to use GPUs already in use, run the command")
+    print(f"    export CUDA_VISIBLE_DEVICES={','.join(map(str, devIDs))}")
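The settings described in the new ``Device Control`` section can also be applied
from within Python. A minimal sketch, assuming one wants to force the CPU and
emulate eight devices (the device count is illustrative)::

    import os

    # Both variables must be set before jax is first imported.
    os.environ["JAX_PLATFORM_NAME"] = "cpu"  # force CPU even if GPUs are present
    os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

    import jax

    # Expect eight CPU devices to be reported.
    print(jax.devices())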
diff --git a/scico/flax/examples/data_generation.py b/scico/flax/examples/data_generation.py
index 423bb82c9..bda52f17a 100644
--- a/scico/flax/examples/data_generation.py
+++ b/scico/flax/examples/data_generation.py
@@ -221,7 +221,7 @@ def generate_ct_data(
     # Normalize sinogram
     sino = sino / size
 
-    # Compute filter back-project in parallel.
+    # Compute filtered back projection in parallel.
    afbp_map = lambda v: jnp.atleast_3d(A.fbp(v.squeeze()))
     start_time = time()
     fbpshd = jax.pmap(lambda i: jax.lax.map(afbp_map, sinoshd[i]))(jnp.arange(nproc))
diff --git a/scico/linop/xray/astra.py b/scico/linop/xray/astra.py
index 28e6b65a0..4bcef5590 100644
--- a/scico/linop/xray/astra.py
+++ b/scico/linop/xray/astra.py
@@ -17,7 +17,7 @@
 not available.
 """
 
-from typing import List, Optional, Tuple
+from typing import List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 
@@ -33,12 +33,28 @@
 else:
     raise e
 
+try:
+    from collections import Iterable  # type: ignore
+except ImportError:
+    import collections
+
+    # Monkey patching required because the latest astra release uses the old module path for Iterable
+    collections.Iterable = collections.abc.Iterable  # type: ignore
 
 from scico.typing import Shape
 
 from .._linop import LinearOperator
 
 
+def set_astra_gpu_index(idx: Union[int, Sequence[int]]):
+    """Set the index/indices of GPU(s) to be used by astra.
+
+    Args:
+        idx: Index or indices of GPU(s).
+    """
+    astra.set_gpu_index(idx)
+
+
 class XRayTransform2D(LinearOperator):
     r"""2D parallel beam X-ray transform based on the ASTRA toolbox.
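A minimal usage sketch for the ``set_astra_gpu_index`` function added above; the
GPU indices are illustrative::

    from scico.linop.xray.astra import set_astra_gpu_index

    set_astra_gpu_index(0)       # restrict astra to the first GPU
    set_astra_gpu_index([0, 1])  # or pass a sequence to allow multiple GPUs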
@@ -108,11 +124,16 @@ def __init__(
             "Please see the astra documentation for details."
         )
 
-        dev0 = jax.devices()[0]
-        if dev0.platform == "cpu" or device == "cpu":
-            self.device = "cpu"
-        elif dev0.platform == "gpu" and device in ["gpu", "auto"]:
-            self.device = "gpu"
+        if device in ["cpu", "gpu"]:
+            # If cpu or gpu selected, attempt to comply (no checking to
+            # confirm that a gpu is available to astra).
+            self.device = device
+        elif device == "auto":
+            # If auto selected, use cpu or gpu depending on the default
+            # jax device (for simplicity, no checking whether a gpu is
+            # available to astra when one is not available to jax).
+            dev0 = jax.devices()[0]
+            self.device = dev0.platform
         else:
             raise ValueError(f"Invalid device specified; got {device}.")
@@ -280,10 +301,6 @@ def __init__(
         self.input_shape: tuple = input_shape
         self.vol_geom = astra.create_vol_geom(input_shape[1], input_shape[2], input_shape[0])
 
-        dev0 = jax.devices()[0]
-        if dev0.platform == "cpu":
-            raise ValueError("No CPU algorithm for 3D projection and GPU not available.")
-
         # Wrap our non-jax function to indicate we will supply fwd/rev mode functions
         self._eval = jax.custom_vjp(self._proj)
         self._eval.defvjp(lambda x: (self._proj(x), None), lambda _, y: (self._bproj(y),))  # type: ignore
diff --git a/scico/optimize/_ladmm.py b/scico/optimize/_ladmm.py
index 26b049311..d2e766883 100644
--- a/scico/optimize/_ladmm.py
+++ b/scico/optimize/_ladmm.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright (C) 2021-2023 by SCICO Developers
+# Copyright (C) 2021-2024 by SCICO Developers
 # All rights reserved. BSD 3-clause License.
 # This file is part of the SCICO package. Details of the copyright and
 # user license can be found in the 'LICENSE' file distributed with the
@@ -11,7 +11,7 @@
 # see https://www.python.org/dev/peps/pep-0563/
 from __future__ import annotations
 
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import scico.numpy as snp
 from scico.functional import Functional
@@ -149,7 +149,7 @@ def minimizer(self):
     def objective(
         self,
         x: Optional[Union[Array, BlockArray]] = None,
-        z: Optional[List[Union[Array, BlockArray]]] = None,
+        z: Optional[Union[Array, BlockArray]] = None,
     ) -> float:
         r"""Evaluate the objective function.
diff --git a/scico/optimize/_padmm.py b/scico/optimize/_padmm.py
index ee3d5c516..37cb4ff2b 100644
--- a/scico/optimize/_padmm.py
+++ b/scico/optimize/_padmm.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright (C) 2022-2023 by SCICO Developers
+# Copyright (C) 2022-2024 by SCICO Developers
 # All rights reserved. BSD 3-clause License.
 # This file is part of the SCICO package. Details of the copyright and
 # user license can be found in the 'LICENSE' file distributed with the
@@ -11,7 +11,7 @@
 # see https://www.python.org/dev/peps/pep-0563/
 from __future__ import annotations
 
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import scico.numpy as snp
 from scico import cvjp, jvp
@@ -144,7 +144,7 @@ def minimizer(self):
     def objective(
         self,
         x: Optional[Union[Array, BlockArray]] = None,
-        z: Optional[List[Union[Array, BlockArray]]] = None,
+        z: Optional[Union[Array, BlockArray]] = None,
     ) -> float:
         r"""Evaluate the objective function.
@@ -289,7 +289,7 @@ def __init__(
     def norm_primal_residual(
         self,
         x: Optional[Union[Array, BlockArray]] = None,
-        z: Optional[List[Union[Array, BlockArray]]] = None,
+        z: Optional[Union[Array, BlockArray]] = None,
     ) -> float:
         r"""Compute the :math:`\ell_2` norm of the primal residual.
@@ -507,7 +507,7 @@ def __init__(
     def norm_primal_residual(
         self,
         x: Optional[Union[Array, BlockArray]] = None,
-        z: Optional[List[Union[Array, BlockArray]]] = None,
+        z: Optional[Union[Array, BlockArray]] = None,
     ) -> float:
         r"""Compute the :math:`\ell_2` norm of the primal residual.
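The ``XRayTransform3D`` hunk above retains the ``jax.custom_vjp`` wrapping of the
non-jax projection functions. A self-contained sketch of that pattern, with
trivial scaling functions standing in for the astra projector and back projector
(both hypothetical)::

    import jax
    import jax.numpy as jnp

    # Stand-ins for the astra projector and its adjoint (back projector).
    fwd = lambda x: 2.0 * x
    adj = lambda y: 2.0 * y

    proj = jax.custom_vjp(fwd)
    # The forward rule returns the output and no residuals; the backward
    # rule applies the adjoint to the cotangent, as in the diff above.
    proj.defvjp(lambda x: (fwd(x), None), lambda _, y: (adj(y),))

    x = jnp.ones(4)
    _, vjp_fun = jax.vjp(proj, x)
    print(vjp_fun(jnp.ones(4)))  # the gradient path applies adj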
diff --git a/scico/test/linop/xray/test_astra.py b/scico/test/linop/xray/test_astra.py
index 3763f377a..6ed0b1efd 100644
--- a/scico/test/linop/xray/test_astra.py
+++ b/scico/test/linop/xray/test_astra.py
@@ -129,15 +129,6 @@ def test_jit_in_DiagonalStack():
     H.T @ snp.zeros(H.output_shape, dtype=snp.float32)
 
 
-@pytest.mark.skipif(jax.devices()[0].platform != "cpu", reason="checking CPU behavior")
-def test_3D_on_CPU():
-    x = snp.zeros((4, 5, 6))
-    with pytest.raises(ValueError):
-        A = XRayTransform3D(
-            x.shape, det_count=[6, 6], det_spacing=[1.0, 1.0], angles=snp.linspace(0, snp.pi, 10)
-        )
-
-
 @pytest.mark.skipif(jax.devices()[0].platform != "gpu", reason="checking GPU behavior")
 def test_3D_on_GPU():
     x = snp.zeros((4, 5, 6))
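The revised device selection logic is exercised via the ``device`` argument of
the transform constructors. A hedged construction sketch, assuming the 2D
constructor takes scalar ``det_count`` and ``det_spacing`` arguments analogous
to the list-valued ones in the 3D test above (all geometry values are
illustrative)::

    import numpy as np

    from scico.linop.xray.astra import XRayTransform2D

    A = XRayTransform2D(
        (256, 256),                        # input (volume) shape
        det_count=384,                     # hypothetical detector count
        det_spacing=1.0,                   # hypothetical detector spacing
        angles=np.linspace(0, np.pi, 90),  # projection angles
        device="auto",                     # "cpu", "gpu", or "auto"
    )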