Skip to content

Commit

Permalink
Merge pull request #11 from alan-turing-institute/disk_processing
Browse files Browse the repository at this point in the history
Connected processing to loading the dataset, added disk IO tests.
  • Loading branch information
mooniean authored Feb 22, 2024
2 parents f45e418 + a4b8b31 commit c6b4672
Show file tree
Hide file tree
Showing 138 changed files with 239 additions and 28 deletions.
2 changes: 1 addition & 1 deletion src/caked/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def load(self, datapath, datatype):
pass

@abstractmethod
def process(self):
def process(self, paths, datatype):
pass

@abstractmethod
Expand Down
41 changes: 35 additions & 6 deletions src/caked/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@

from .base import AbstractDataLoader, AbstractDataset

np.random.seed(42)
TRANSFORM_OPTIONS = ["normalise", "gaussianblur", "shiftmin"]


class DiskDataLoader(AbstractDataLoader):
def __init__(
Expand All @@ -24,11 +27,15 @@ def __init__(
training: bool = True,
classes: list[str] | None = None,
pipeline: str = "disk",
transformations: str | None = None,
) -> None:
self.dataset_size = dataset_size
self.save_to_disk = save_to_disk
self.training = training
self.pipeline = pipeline
self.transformations = transformations
self.debug = False

if classes is None:
self.classes = []
else:
Expand All @@ -37,7 +44,8 @@ def __init__(
def load(self, datapath, datatype) -> None:
paths = [f for f in os.listdir(datapath) if "." + datatype in f]

random.shuffle(paths)
if not self.debug:
random.shuffle(paths)

# ids right now depend on the data being saved with a certain format (id in the first part of the name, separated by _)
# TODO: make this more general/document in the README
Expand Down Expand Up @@ -69,10 +77,32 @@ def load(self, datapath, datatype) -> None:
if self.dataset_size is not None:
paths = paths[: self.dataset_size]

self.dataset = DiskDataset(paths=paths, datatype=datatype)
if self.transformations is None:
self.dataset = DiskDataset(paths=paths, datatype=datatype)
else:
self.dataset = self.process(paths=paths, datatype=datatype)

def process(self):
return super().process()
def process(self, paths: list[str], datatype: str):
    """Build a ``DiskDataset`` that applies the configured transformations.

    Parameters
    ----------
    paths : list[str]
        Filesystem paths of the dataset items.
    datatype : str
        File extension / format of the items (e.g. ``"mrc"``, ``"npy"``).

    Returns
    -------
    DiskDataset
        Dataset configured with the parsed transformation flags.

    Raises
    ------
    RuntimeError
        If no transformations were supplied to the loader.
    """
    if self.transformations is None:
        msg = "No processing to do as no transformations were provided."
        raise RuntimeError(msg)
    transforms = self.transformations.split(",")

    # Separate the optional "rescale=<int>" entry from the boolean flags.
    # Filtering into a new list fixes the original remove-while-iterating
    # bug, where deleting an element advanced the iterator past its
    # successor.
    rescale = 0
    flags = []
    for transform in transforms:
        if transform.startswith("rescale"):
            rescale = int(transform.split("=")[-1])
        else:
            flags.append(transform)

    # Element-wise membership in TRANSFORM_OPTIONS order:
    # (normalise, gaussianblur, shiftmin). np.isin replaces the
    # deprecated np.in1d.
    normalise, gaussianblur, shiftmin = np.isin(TRANSFORM_OPTIONS, flags)

    return DiskDataset(
        paths=paths,
        datatype=datatype,
        rescale=rescale,
        normalise=normalise,
        gaussianblur=gaussianblur,
        shiftmin=shiftmin,
    )

def get_loader(self, batch_size: int, split_size: float | None = None):
if self.training:
Expand Down Expand Up @@ -120,7 +150,7 @@ def __init__(
self,
paths: list[str],
datatype: str = "npy",
rescale: bool = False,
rescale: int = 0,
shiftmin: bool = False,
gaussianblur: bool = False,
normalise: bool = False,
Expand All @@ -130,7 +160,6 @@ def __init__(
self.rescale = rescale
self.normalise = normalise
self.gaussianblur = gaussianblur
self.rescale = rescale
self.transform = input_transform
self.datatype = datatype
self.shiftmin = shiftmin
Expand Down
186 changes: 165 additions & 21 deletions tests/test_disk_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,39 @@

from pathlib import Path

import numpy as np
import pytest
import torch
from tests import testdata_mrc
from tests import testdata_mrc, testdata_npy

from caked.dataloader import DiskDataLoader, DiskDataset

ORIG_DIR = Path.cwd()
TEST_DATA_MRC = Path(testdata_mrc.__file__).parent
TEST_DATA_NPY = Path(testdata_npy.__file__).parent
DISK_PIPELINE = "disk"
DATASET_SIZE_ALL = None
DATASET_SIZE_SOME = 3
DISK_CLASSES_FULL = ["1b23", "1dfo", "1dkg", "1e3p"]
DISK_CLASSES_SOME = ["1b23", "1dkg"]
DISK_CLASSES_MISSING = ["2b3a", "1b23"]
DISK_CLASSES_FULL_MRC = ["1b23", "1dfo", "1dkg", "1e3p"]
DISK_CLASSES_SOME_MRC = ["1b23", "1dkg"]
DISK_CLASSES_MISSING_MRC = ["2b3a", "1b23"]
DISK_CLASSES_FULL_NPY = ["2", "5", "a", "d", "e", "i", "j", "l", "s", "u", "v", "x"]
DISK_CLASSES_SOME_NPY = ["2", "5"]
DISK_CLASSES_MISSING_NPY = ["2", "a", "1"]

DISK_CLASSES_NONE = None
DATATYPE_MRC = "mrc"
DATATYPE_NPY = "npy"
TRANSFORM_ALL = "normalise,gaussianblur,shiftmin"
TRANSFORM_ALL_RESCALE = "normalise,gaussianblur,shiftmin,rescale=0"
TRANSFORM_SOME = "normalise,gaussianblur"
TRANSFORM_RESCALE = "rescale=32"


def test_class_instantiation():
test_loader = DiskDataLoader(
pipeline=DISK_PIPELINE,
classes=DISK_CLASSES_SOME,
classes=DISK_CLASSES_SOME_MRC,
dataset_size=DATASET_SIZE_SOME,
save_to_disk=False,
training=True,
Expand All @@ -32,8 +43,13 @@ def test_class_instantiation():
assert test_loader.pipeline == DISK_PIPELINE


def test_dataset_instantiation():
test_dataset = DiskDataset(paths=["test"])
def test_dataset_instantiation_mrc():
    """A DiskDataset can be constructed from the MRC test-data location."""
    dataset = DiskDataset(paths=TEST_DATA_MRC, datatype=DATATYPE_MRC)
    assert isinstance(dataset, DiskDataset)


def test_dataset_instantiation_npy():
    """A DiskDataset can be constructed from the NPY test-data location.

    Fixes a copy-paste bug: the test previously used the MRC constants,
    duplicating test_dataset_instantiation_mrc and never exercising npy.
    """
    test_dataset = DiskDataset(paths=TEST_DATA_NPY, datatype=DATATYPE_NPY)
    assert isinstance(test_dataset, DiskDataset)


Expand All @@ -43,34 +59,50 @@ def test_load_dataset_no_classes():
)
test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
assert isinstance(test_loader.dataset, DiskDataset)
assert len(test_loader.classes) == len(DISK_CLASSES_FULL)
assert all(a == b for a, b in zip(test_loader.classes, DISK_CLASSES_FULL))
assert len(test_loader.classes) == len(DISK_CLASSES_FULL_MRC)
assert all(a == b for a, b in zip(test_loader.classes, DISK_CLASSES_FULL_MRC))


def test_load_dataset_all_classes():
def test_load_dataset_all_classes_mrc():
test_loader = DiskDataLoader(
pipeline=DISK_PIPELINE, classes=DISK_CLASSES_FULL, dataset_size=DATASET_SIZE_ALL
pipeline=DISK_PIPELINE,
classes=DISK_CLASSES_FULL_MRC,
dataset_size=DATASET_SIZE_ALL,
)
test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
assert isinstance(test_loader.dataset, DiskDataset)
assert len(test_loader.classes) == len(DISK_CLASSES_FULL)
assert all(a == b for a, b in zip(test_loader.classes, DISK_CLASSES_FULL))
assert len(test_loader.classes) == len(DISK_CLASSES_FULL_MRC)
assert all(a == b for a, b in zip(test_loader.classes, DISK_CLASSES_FULL_MRC))


def test_load_dataset_all_classes_npy():
    """Loading the NPY test data with every class listed keeps all classes."""
    loader = DiskDataLoader(
        pipeline=DISK_PIPELINE,
        classes=DISK_CLASSES_FULL_NPY,
        dataset_size=DATASET_SIZE_ALL,
    )
    loader.load(datapath=TEST_DATA_NPY, datatype=DATATYPE_NPY)
    assert isinstance(loader.dataset, DiskDataset)
    assert len(loader.classes) == len(DISK_CLASSES_FULL_NPY)
    for found, expected in zip(loader.classes, DISK_CLASSES_FULL_NPY):
        assert found == expected


def test_load_dataset_some_classes():
test_loader = DiskDataLoader(
pipeline=DISK_PIPELINE, classes=DISK_CLASSES_SOME, dataset_size=DATASET_SIZE_ALL
pipeline=DISK_PIPELINE,
classes=DISK_CLASSES_SOME_MRC,
dataset_size=DATASET_SIZE_ALL,
)
test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
assert isinstance(test_loader.dataset, DiskDataset)
assert len(test_loader.classes) == len(DISK_CLASSES_SOME)
assert all(a == b for a, b in zip(test_loader.classes, DISK_CLASSES_SOME))
assert len(test_loader.classes) == len(DISK_CLASSES_SOME_MRC)
assert all(a == b for a, b in zip(test_loader.classes, DISK_CLASSES_SOME_MRC))


def test_load_dataset_missing_class():
test_loader = DiskDataLoader(
pipeline=DISK_PIPELINE,
classes=DISK_CLASSES_MISSING,
classes=DISK_CLASSES_MISSING_MRC,
dataset_size=DATASET_SIZE_ALL,
)
with pytest.raises(Exception, match=r".*Missing classes: .*"):
Expand All @@ -84,14 +116,14 @@ def test_one_image():
test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
test_dataset = test_loader.dataset
test_item_image, test_item_name = test_dataset.__getitem__(1)
assert test_item_name in DISK_CLASSES_FULL
assert test_item_name in DISK_CLASSES_FULL_MRC
assert isinstance(test_item_image, torch.Tensor)


def test_get_loader_training_false():
test_loader = DiskDataLoader(
pipeline=DISK_PIPELINE,
classes=DISK_CLASSES_FULL,
classes=DISK_CLASSES_FULL_MRC,
dataset_size=DATASET_SIZE_ALL,
training=False,
)
Expand All @@ -103,7 +135,7 @@ def test_get_loader_training_false():
def test_get_loader_training_true():
test_loader = DiskDataLoader(
pipeline=DISK_PIPELINE,
classes=DISK_CLASSES_FULL,
classes=DISK_CLASSES_FULL_MRC,
dataset_size=DATASET_SIZE_ALL,
training=True,
)
Expand All @@ -118,7 +150,7 @@ def test_get_loader_training_true():
def test_get_loader_training_fail():
test_loader = DiskDataLoader(
pipeline=DISK_PIPELINE,
classes=DISK_CLASSES_FULL,
classes=DISK_CLASSES_FULL_MRC,
dataset_size=DATASET_SIZE_ALL,
training=True,
)
Expand All @@ -127,3 +159,115 @@ def test_get_loader_training_fail():
torch_loader_train, torch_loader_val = test_loader.get_loader(
split_size=1, batch_size=64
)


def test_processing_data_all_transforms():
    """Requesting every transform sets all three flags on the dataset."""
    loader = DiskDataLoader(
        pipeline=DISK_PIPELINE,
        classes=DISK_CLASSES_FULL_MRC,
        dataset_size=DATASET_SIZE_ALL,
        training=True,
        transformations=TRANSFORM_ALL,
    )
    loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
    dataset = loader.dataset
    assert dataset.normalise
    assert dataset.shiftmin
    assert dataset.gaussianblur
    image, label = next(iter(dataset))
    image = np.squeeze(image.cpu().numpy())
    # All three axes of the loaded volume have equal extent.
    assert len(image[0]) == len(image[1]) == len(image[2])
    assert label in DISK_CLASSES_FULL_MRC


def test_processing_data_some_transforms_npy():
    """Non-rescaling transforms leave the NPY image size unchanged."""
    loader_transformed = DiskDataLoader(
        pipeline=DISK_PIPELINE,
        classes=DISK_CLASSES_FULL_NPY,
        dataset_size=DATASET_SIZE_ALL,
        training=True,
        transformations=TRANSFORM_SOME,
    )
    loader_plain = DiskDataLoader(
        pipeline=DISK_PIPELINE,
        classes=DISK_CLASSES_FULL_NPY,
        dataset_size=DATASET_SIZE_ALL,
        training=True,
    )
    loader_plain.load(datapath=TEST_DATA_NPY, datatype=DATATYPE_NPY)
    loader_transformed.load(datapath=TEST_DATA_NPY, datatype=DATATYPE_NPY)
    # TRANSFORM_SOME asks for normalise and gaussianblur but not shiftmin.
    assert loader_transformed.dataset.normalise
    assert not loader_transformed.dataset.shiftmin
    assert loader_transformed.dataset.gaussianblur
    plain_image, plain_label = next(iter(loader_plain.dataset))
    plain_image = np.squeeze(plain_image.cpu().numpy())
    assert len(plain_image[0]) == len(plain_image[1])
    assert plain_label in DISK_CLASSES_FULL_NPY
    transf_image, transf_label = next(iter(loader_transformed.dataset))
    transf_image = np.squeeze(transf_image.cpu().numpy())
    assert len(transf_image[0]) == len(transf_image[1])
    assert transf_label in DISK_CLASSES_FULL_NPY
    # No rescale requested, so transformed images keep the original size.
    assert len(plain_image[0]) == len(transf_image[0])
    assert len(plain_image[1]) == len(transf_image[1])


def test_processing_data_rescale():
    """rescale=0 keeps other flags; rescale=32 alone sets only the size."""

    def load_mrc_loader(transformations):
        # Build and load an MRC loader with the given transformation string.
        loader = DiskDataLoader(
            pipeline=DISK_PIPELINE,
            classes=DISK_CLASSES_FULL_MRC,
            dataset_size=DATASET_SIZE_ALL,
            training=True,
            transformations=transformations,
        )
        loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
        return loader

    def check_first_item(loader):
        # The first item decodes to a cube-shaped array with a known label.
        image, label = next(iter(loader.dataset))
        image = np.squeeze(image.cpu().numpy())
        assert len(image[0]) == len(image[1]) == len(image[2])
        assert label in DISK_CLASSES_FULL_MRC

    loader = load_mrc_loader(TRANSFORM_ALL_RESCALE)
    assert loader.dataset.normalise
    assert loader.dataset.shiftmin
    assert loader.dataset.gaussianblur
    assert loader.dataset.rescale == 0
    check_first_item(loader)

    loader = load_mrc_loader(TRANSFORM_RESCALE)
    assert not loader.dataset.normalise
    assert not loader.dataset.shiftmin
    assert not loader.dataset.gaussianblur
    assert loader.dataset.rescale == 32
    check_first_item(loader)


def test_processing_after_load():
    """Re-loading after setting transformations reprocesses the same files."""
    loader = DiskDataLoader(
        pipeline=DISK_PIPELINE,
        classes=DISK_CLASSES_FULL_MRC,
        dataset_size=DATASET_SIZE_ALL,
        training=False,
    )
    # debug=True disables path shuffling so both loads see the same order.
    loader.debug = True
    loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
    assert loader.transformations is None
    assert not loader.dataset.normalise
    assert not loader.dataset.shiftmin
    assert not loader.dataset.gaussianblur
    before = loader.dataset
    loader.transformations = TRANSFORM_ALL_RESCALE
    loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
    after = loader.dataset
    assert loader.dataset.normalise
    assert loader.dataset.shiftmin
    assert loader.dataset.gaussianblur
    assert len(after) == len(before)
    image_before, label_before = next(iter(before))
    image_after, label_after = next(iter(after))
    # Same file, same label — but the transforms must change the pixels.
    assert label_before == label_after
    assert not torch.equal(image_before, image_after)
Binary file added tests/testdata_npy/2_0_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/2_1_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/2_2_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/2_3_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/2_4_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/2_5_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/2_6_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/2_8_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/2_9_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/5_0_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/5_10_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/5_1_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/5_2_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/5_3_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/5_4_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/5_5_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/5_6_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/5_7_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/5_9_img.npy
Binary file not shown.
Empty file added tests/testdata_npy/__init__.py
Empty file.
Binary file added tests/testdata_npy/a_0_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/a_10_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/a_1_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/a_2_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/a_3_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/a_4_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/a_5_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/a_6_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/a_7_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/a_8_img.npy
Binary file not shown.
Loading

0 comments on commit c6b4672

Please sign in to comment.