diff --git a/.github/workflows/pytest_macos.yml b/.github/workflows/pytest_macos.yml index 0b097029c..82cf74b95 100644 --- a/.github/workflows/pytest_macos.yml +++ b/.github/workflows/pytest_macos.yml @@ -59,11 +59,13 @@ jobs: pip install pytest-split pip install -r requirements.txt pip install -r dev_requirements.txt + mamba install -c conda-forge svmbir>=0.3.3 mamba install -c astra-toolbox astra-toolbox mamba install -c conda-forge pyyaml pip install --upgrade --force-reinstall scipy>=1.6.0 # Temporary fix for GLIBCXX_3.4.30 not found in conda forge version + pip install bm3d>=4.0.0 pip install bm4d>=4.0.0 - pip install "ray[tune]>=2.0.0" + pip install "ray[tune]>=2.5.0" pip install hyperopt # Install package to be tested - name: Install package to be tested diff --git a/.github/workflows/pytest_ubuntu.yml b/.github/workflows/pytest_ubuntu.yml index db24830e5..f60c9d4ba 100644 --- a/.github/workflows/pytest_ubuntu.yml +++ b/.github/workflows/pytest_ubuntu.yml @@ -62,12 +62,13 @@ jobs: pip install pytest-split pip install -r requirements.txt pip install -r dev_requirements.txt + mamba install -c conda-forge svmbir>=0.3.3 mamba install -c astra-toolbox astra-toolbox mamba install -c conda-forge pyyaml pip install --upgrade --force-reinstall scipy>=1.6.0 # Temporary fix for GLIBCXX_3.4.30 not found in conda forge version pip install bm3d>=4.0.0 pip install bm4d>=4.2.2 - pip install "ray[tune]>=2.0.0" + pip install "ray[tune]>=2.5.0" pip install hyperopt # Install package to be tested - name: Install package to be tested diff --git a/data b/data index a33ca716e..d5aff70f9 160000 --- a/data +++ b/data @@ -1 +1 @@ -Subproject commit a33ca716e42ba7593d6120752053869fce8b1abb +Subproject commit d5aff70f95d33abf72e785fb945cc556db16cd12 diff --git a/dev_requirements.txt b/dev_requirements.txt index aa21af84b..213aa591e 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,3 +1,4 @@ +-r requirements.txt pylint pytest>=7.3.0 pytest-runner diff --git a/examples/examples_requirements.txt b/examples/examples_requirements.txt index 125852f90..15d0a2a1e 100644 --- a/examples/examples_requirements.txt +++ b/examples/examples_requirements.txt @@ -1,6 +1,8 @@ -r ../requirements.txt -astra-toolbox +tifffile colour_demosaicing +svmbir>=0.3.3 +astra-toolbox xdesign>=0.5.5 ray[tune,train]>=2.5.0 hyperopt diff --git a/requirements.txt b/requirements.txt index 97f0263e5..76722063a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ +typing_extensions numpy>=1.20.0 scipy>=1.6.0 -tifffile imageio>=2.17 matplotlib jaxlib>=0.4.3,<=0.4.23 jax>=0.4.3,<=0.4.23 +orbax-checkpoint flax>=0.6.1,<=0.7.5 -svmbir>=0.3.3 pyabel>=0.9.0 diff --git a/scico/flax/train/checkpoints.py b/scico/flax/train/checkpoints.py index 6ef233d80..dac4fd872 100644 --- a/scico/flax/train/checkpoints.py +++ b/scico/flax/train/checkpoints.py @@ -1,17 +1,19 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2022 by SCICO Developers +# Copyright (C) 2022-2023 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. """Utilities for checkpointing Flax models.""" + + from pathlib import Path from typing import Union import jax -import orbax +import orbax.checkpoint from flax.training import orbax_utils @@ -29,13 +31,14 @@ def checkpoint_restore( parameters. workdir: Checkpoint file or directory of checkpoints to restore from. - ok_no_ckpt: Flag to indicate if a checkpoint is expected. Default: - False, a checkpoint is expected and an error is generated. + ok_no_ckpt: Flag to indicate if a checkpoint is expected. If + ``False``, an error is generated if a checkpoint is not + found. Returns: - A restored Flax train state updated from checkpoint file is returned. - If no checkpoint files are present and checkpoints are not strictly - expected it returns the passed-in `state` unchanged. + A restored Flax train state updated from checkpoint file is + returned. If no checkpoint files are present and checkpoints are + not strictly expected it returns the passed-in `state` unchanged. Raises: FileNotFoundError: If a checkpoint is expected and is not found. @@ -68,7 +71,7 @@ def checkpoint_save(state: TrainState, config: ConfigDict, workdir: Union[str, P state: Flax train state which includes model and optimiser parameters. config: Python dictionary including model train configuration. - workdir: str or pathlib-like path to store checkpoint files in. + workdir: Path in which to store checkpoint files. """ if jax.process_index() == 0: orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer() diff --git a/setup.py b/setup.py index 30e37e751..b7941b5ec 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ import site import sys -from setuptools import find_packages, setup +from setuptools import find_namespace_packages, setup # Import module scico._version without executing __init__.py spec = importlib.util.spec_from_file_location("_version", os.path.join("scico", "_version.py")) @@ -20,7 +20,10 @@ name = "scico" version = package_version() -packages = find_packages() +# Add argument exclude=["test", "test.*"] to exclude test subpackage +packages = find_namespace_packages(where="scico") +packages = [f"scico.{m}" for m in packages] + longdesc = """ SCICO is a Python package for solving the inverse problems that arise in scientific imaging applications. Its primary focus is providing methods for solving ill-posed inverse problems by using an appropriate prior model of the reconstruction space. SCICO includes a growing suite of operators, cost functionals, regularizers, and optimization routines that may be combined to solve a wide range of problems, and is designed so that it is easy to add new building blocks. SCICO is built on top of JAX, which provides features such as automatic gradient calculation and GPU acceleration.