Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Connected processing to loading the dataset. #11

Merged
merged 6 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/caked/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def load(self, datapath, datatype):
pass

@abstractmethod
def process(self):
def process(self, paths, datatype):
    """Build a dataset from *paths* of the given *datatype*, applying any
    configured transformations (implemented by concrete loaders)."""
    pass

@abstractmethod
Expand Down
41 changes: 35 additions & 6 deletions src/caked/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@

from .base import AbstractDataLoader, AbstractDataset

np.random.seed(42)
TRANSFORM_OPTIONS = ["normalise", "gaussianblur", "shiftmin"]


class DiskDataLoader(AbstractDataLoader):
def __init__(
Expand All @@ -24,11 +27,15 @@ def __init__(
training: bool = True,
classes: list[str] | None = None,
pipeline: str = "disk",
transformations: str | None = None,
) -> None:
self.dataset_size = dataset_size
self.save_to_disk = save_to_disk
self.training = training
self.pipeline = pipeline
self.transformations = transformations
self.debug = False

if classes is None:
self.classes = []
else:
Expand All @@ -37,7 +44,8 @@ def __init__(
def load(self, datapath, datatype) -> None:
paths = [f for f in os.listdir(datapath) if "." + datatype in f]

random.shuffle(paths)
if not self.debug:
random.shuffle(paths)

# ids right now depend on the data being saved with a certain format (id in the first part of the name, separated by _)
# TODO: make this more general/document in the README
Expand Down Expand Up @@ -69,10 +77,32 @@ def load(self, datapath, datatype) -> None:
if self.dataset_size is not None:
paths = paths[: self.dataset_size]

self.dataset = DiskDataset(paths=paths, datatype=datatype)
if self.transformations is None:
self.dataset = DiskDataset(paths=paths, datatype=datatype)
else:
self.dataset = self.process(paths=paths, datatype=datatype)

def process(self):
return super().process()
def process(self, paths: list[str], datatype: str):
    """Create a :class:`DiskDataset` with the loader's transformations applied.

    ``self.transformations`` is a comma-separated string drawn from
    ``TRANSFORM_OPTIONS`` (``normalise``, ``gaussianblur``, ``shiftmin``)
    plus an optional ``rescale=<int>`` entry.

    :param paths: files to include in the dataset.
    :param datatype: file extension of the data (e.g. ``"mrc"``, ``"npy"``).
    :raises RuntimeError: if no transformations were configured.
    :return: a DiskDataset with the requested transformations enabled.
    """
    if self.transformations is None:
        msg = "No processing to do as no transformations were provided."
        raise RuntimeError(msg)
    transforms = self.transformations.split(",")
    rescale = 0
    # Iterate over a copy: removing from the list being iterated would
    # silently skip the element that follows each removed "rescale=" entry,
    # dropping any transform listed after it.
    for transform in list(transforms):
        if transform.startswith("rescale"):
            transforms.remove(transform)
            rescale = int(transform.split("=")[-1])

    # np.isin supersedes the deprecated np.in1d; one boolean per entry of
    # TRANSFORM_OPTIONS, in that fixed order.
    normalise, gaussianblur, shiftmin = np.isin(TRANSFORM_OPTIONS, transforms)

    return DiskDataset(
        paths=paths,
        datatype=datatype,
        rescale=rescale,
        normalise=normalise,
        gaussianblur=gaussianblur,
        shiftmin=shiftmin,
    )

def get_loader(self, batch_size: int, split_size: float | None = None):
if self.training:
Expand Down Expand Up @@ -120,7 +150,7 @@ def __init__(
self,
paths: list[str],
datatype: str = "npy",
rescale: bool = False,
rescale: int = 0,
shiftmin: bool = False,
gaussianblur: bool = False,
normalise: bool = False,
Expand All @@ -130,7 +160,6 @@ def __init__(
self.rescale = rescale
self.normalise = normalise
self.gaussianblur = gaussianblur
self.rescale = rescale
self.transform = input_transform
self.datatype = datatype
self.shiftmin = shiftmin
Expand Down
186 changes: 165 additions & 21 deletions tests/test_disk_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,39 @@

from pathlib import Path

import numpy as np
import pytest
import torch
from tests import testdata_mrc
from tests import testdata_mrc, testdata_npy

from caked.dataloader import DiskDataLoader, DiskDataset

ORIG_DIR = Path.cwd()
TEST_DATA_MRC = Path(testdata_mrc.__file__).parent
TEST_DATA_NPY = Path(testdata_npy.__file__).parent
DISK_PIPELINE = "disk"
DATASET_SIZE_ALL = None
DATASET_SIZE_SOME = 3
DISK_CLASSES_FULL = ["1b23", "1dfo", "1dkg", "1e3p"]
DISK_CLASSES_SOME = ["1b23", "1dkg"]
DISK_CLASSES_MISSING = ["2b3a", "1b23"]
DISK_CLASSES_FULL_MRC = ["1b23", "1dfo", "1dkg", "1e3p"]
DISK_CLASSES_SOME_MRC = ["1b23", "1dkg"]
DISK_CLASSES_MISSING_MRC = ["2b3a", "1b23"]
DISK_CLASSES_FULL_NPY = ["2", "5", "a", "d", "e", "i", "j", "l", "s", "u", "v", "x"]
DISK_CLASSES_SOME_NPY = ["2", "5"]
DISK_CLASSES_MISSING_NPY = ["2", "a", "1"]

DISK_CLASSES_NONE = None
DATATYPE_MRC = "mrc"
DATATYPE_NPY = "npy"
TRANSFORM_ALL = "normalise,gaussianblur,shiftmin"
TRANSFORM_ALL_RESCALE = "normalise,gaussianblur,shiftmin,rescale=0"
TRANSFORM_SOME = "normalise,gaussianblur"
TRANSFORM_RESCALE = "rescale=32"


def test_class_instantiation():
test_loader = DiskDataLoader(
pipeline=DISK_PIPELINE,
classes=DISK_CLASSES_SOME,
classes=DISK_CLASSES_SOME_MRC,
dataset_size=DATASET_SIZE_SOME,
save_to_disk=False,
training=True,
Expand All @@ -32,8 +43,13 @@ def test_class_instantiation():
assert test_loader.pipeline == DISK_PIPELINE


def test_dataset_instantiation():
test_dataset = DiskDataset(paths=["test"])
def test_dataset_instantiation_mrc():
    """A DiskDataset can be constructed directly from the MRC test data."""
    dataset = DiskDataset(paths=TEST_DATA_MRC, datatype=DATATYPE_MRC)
    assert isinstance(dataset, DiskDataset)


def test_dataset_instantiation_npy():
    """A DiskDataset can be constructed directly from the NPY test data."""
    # Fix copy-paste bug: this test previously used TEST_DATA_MRC/DATATYPE_MRC,
    # so the npy path was never actually exercised.
    test_dataset = DiskDataset(paths=TEST_DATA_NPY, datatype=DATATYPE_NPY)
    assert isinstance(test_dataset, DiskDataset)


Expand All @@ -43,34 +59,50 @@ def test_load_dataset_no_classes():
)
test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
assert isinstance(test_loader.dataset, DiskDataset)
assert len(test_loader.classes) == len(DISK_CLASSES_FULL)
assert all(a == b for a, b in zip(test_loader.classes, DISK_CLASSES_FULL))
assert len(test_loader.classes) == len(DISK_CLASSES_FULL_MRC)
assert all(a == b for a, b in zip(test_loader.classes, DISK_CLASSES_FULL_MRC))


def test_load_dataset_all_classes():
def test_load_dataset_all_classes_mrc():
test_loader = DiskDataLoader(
pipeline=DISK_PIPELINE, classes=DISK_CLASSES_FULL, dataset_size=DATASET_SIZE_ALL
pipeline=DISK_PIPELINE,
classes=DISK_CLASSES_FULL_MRC,
dataset_size=DATASET_SIZE_ALL,
)
test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
assert isinstance(test_loader.dataset, DiskDataset)
assert len(test_loader.classes) == len(DISK_CLASSES_FULL)
assert all(a == b for a, b in zip(test_loader.classes, DISK_CLASSES_FULL))
assert len(test_loader.classes) == len(DISK_CLASSES_FULL_MRC)
assert all(a == b for a, b in zip(test_loader.classes, DISK_CLASSES_FULL_MRC))


def test_load_dataset_all_classes_npy():
    """Loading the npy test data with every class listed keeps all classes."""
    loader = DiskDataLoader(
        pipeline=DISK_PIPELINE,
        classes=DISK_CLASSES_FULL_NPY,
        dataset_size=DATASET_SIZE_ALL,
    )
    loader.load(datapath=TEST_DATA_NPY, datatype=DATATYPE_NPY)
    assert isinstance(loader.dataset, DiskDataset)
    # Same length and same element order as the requested class list.
    assert list(loader.classes) == list(DISK_CLASSES_FULL_NPY)


def test_load_dataset_some_classes():
test_loader = DiskDataLoader(
pipeline=DISK_PIPELINE, classes=DISK_CLASSES_SOME, dataset_size=DATASET_SIZE_ALL
pipeline=DISK_PIPELINE,
classes=DISK_CLASSES_SOME_MRC,
dataset_size=DATASET_SIZE_ALL,
)
test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
assert isinstance(test_loader.dataset, DiskDataset)
assert len(test_loader.classes) == len(DISK_CLASSES_SOME)
assert all(a == b for a, b in zip(test_loader.classes, DISK_CLASSES_SOME))
assert len(test_loader.classes) == len(DISK_CLASSES_SOME_MRC)
assert all(a == b for a, b in zip(test_loader.classes, DISK_CLASSES_SOME_MRC))


def test_load_dataset_missing_class():
test_loader = DiskDataLoader(
pipeline=DISK_PIPELINE,
classes=DISK_CLASSES_MISSING,
classes=DISK_CLASSES_MISSING_MRC,
dataset_size=DATASET_SIZE_ALL,
)
with pytest.raises(Exception, match=r".*Missing classes: .*"):
Expand All @@ -84,14 +116,14 @@ def test_one_image():
test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
test_dataset = test_loader.dataset
test_item_image, test_item_name = test_dataset.__getitem__(1)
assert test_item_name in DISK_CLASSES_FULL
assert test_item_name in DISK_CLASSES_FULL_MRC
assert isinstance(test_item_image, torch.Tensor)


def test_get_loader_training_false():
test_loader = DiskDataLoader(
pipeline=DISK_PIPELINE,
classes=DISK_CLASSES_FULL,
classes=DISK_CLASSES_FULL_MRC,
dataset_size=DATASET_SIZE_ALL,
training=False,
)
Expand All @@ -103,7 +135,7 @@ def test_get_loader_training_false():
def test_get_loader_training_true():
test_loader = DiskDataLoader(
pipeline=DISK_PIPELINE,
classes=DISK_CLASSES_FULL,
classes=DISK_CLASSES_FULL_MRC,
dataset_size=DATASET_SIZE_ALL,
training=True,
)
Expand All @@ -118,7 +150,7 @@ def test_get_loader_training_true():
def test_get_loader_training_fail():
test_loader = DiskDataLoader(
pipeline=DISK_PIPELINE,
classes=DISK_CLASSES_FULL,
classes=DISK_CLASSES_FULL_MRC,
dataset_size=DATASET_SIZE_ALL,
training=True,
)
Expand All @@ -127,3 +159,115 @@ def test_get_loader_training_fail():
torch_loader_train, torch_loader_val = test_loader.get_loader(
split_size=1, batch_size=64
)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, it would be good to test that the data loader returns what we want. I think here we have only tested that the class variables are correct (which is great!), but we won't catch things like the data loader not returning the right size of data, or not selecting the labels correctly.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's a solid point! I checked it but realised it was only on a local function, I'll make that into assertions here


def test_processing_data_all_transforms():
    """All three named transforms are enabled on the dataset after load()."""
    loader = DiskDataLoader(
        pipeline=DISK_PIPELINE,
        classes=DISK_CLASSES_FULL_MRC,
        dataset_size=DATASET_SIZE_ALL,
        training=True,
        transformations=TRANSFORM_ALL,
    )
    loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
    dataset = loader.dataset
    for flag in (dataset.normalise, dataset.shiftmin, dataset.gaussianblur):
        assert flag
    image, label = next(iter(dataset))
    squeezed = np.squeeze(image.cpu().numpy())
    # The volume should remain cubic after processing.
    assert len(squeezed[0]) == len(squeezed[1]) == len(squeezed[2])
    assert label in DISK_CLASSES_FULL_MRC


def test_processing_data_some_transforms_npy():
    """A subset of transforms is applied and image dimensions are preserved."""
    with_transforms = DiskDataLoader(
        pipeline=DISK_PIPELINE,
        classes=DISK_CLASSES_FULL_NPY,
        dataset_size=DATASET_SIZE_ALL,
        training=True,
        transformations=TRANSFORM_SOME,
    )
    without_transforms = DiskDataLoader(
        pipeline=DISK_PIPELINE,
        classes=DISK_CLASSES_FULL_NPY,
        dataset_size=DATASET_SIZE_ALL,
        training=True,
    )
    without_transforms.load(datapath=TEST_DATA_NPY, datatype=DATATYPE_NPY)
    with_transforms.load(datapath=TEST_DATA_NPY, datatype=DATATYPE_NPY)
    # TRANSFORM_SOME enables normalise and gaussianblur but not shiftmin.
    assert with_transforms.dataset.normalise
    assert not with_transforms.dataset.shiftmin
    assert with_transforms.dataset.gaussianblur
    plain_image, plain_label = next(iter(without_transforms.dataset))
    plain_image = np.squeeze(plain_image.cpu().numpy())
    assert len(plain_image[0]) == len(plain_image[1])
    assert plain_label in DISK_CLASSES_FULL_NPY
    proc_image, proc_label = next(iter(with_transforms.dataset))
    proc_image = np.squeeze(proc_image.cpu().numpy())
    assert len(proc_image[0]) == len(proc_image[1])
    assert proc_label in DISK_CLASSES_FULL_NPY
    # Transformations must not change the spatial dimensions of the data.
    assert len(plain_image[0]) == len(proc_image[0])
    assert len(plain_image[1]) == len(proc_image[1])


def test_processing_data_rescale():
test_loader = DiskDataLoader(
pipeline=DISK_PIPELINE,
classes=DISK_CLASSES_FULL_MRC,
dataset_size=DATASET_SIZE_ALL,
training=True,
transformations=TRANSFORM_ALL_RESCALE,
)
test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
assert test_loader.dataset.normalise
assert test_loader.dataset.shiftmin
assert test_loader.dataset.gaussianblur
assert test_loader.dataset.rescale == 0
image, label = next(iter(test_loader.dataset))
image = np.squeeze(image.cpu().numpy())
assert len(image[0]) == len(image[1]) == len(image[2])
assert label in DISK_CLASSES_FULL_MRC

test_loader = DiskDataLoader(
pipeline=DISK_PIPELINE,
classes=DISK_CLASSES_FULL_MRC,
dataset_size=DATASET_SIZE_ALL,
training=True,
transformations=TRANSFORM_RESCALE,
)
test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
assert not test_loader.dataset.normalise
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could add a test to make sure the transformation is actually happening, e.g. that the output dataset (for an example data point) is different when you pass a transformation versus passing none.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Latest commit addresses this!! I hope that's ok, made a debug flag that doesn't shuffle the paths so we can test the processing is happening.

assert not test_loader.dataset.shiftmin
assert not test_loader.dataset.gaussianblur
assert test_loader.dataset.rescale == 32
image, label = next(iter(test_loader.dataset))
image = np.squeeze(image.cpu().numpy())
assert len(image[0]) == len(image[1]) == len(image[2])
assert label in DISK_CLASSES_FULL_MRC


def test_processing_after_load():
    """Setting transformations after a load and re-loading transforms the data."""
    loader = DiskDataLoader(
        pipeline=DISK_PIPELINE,
        classes=DISK_CLASSES_FULL_MRC,
        dataset_size=DATASET_SIZE_ALL,
        training=False,
    )
    # Debug mode skips path shuffling so both loads see the same file order.
    loader.debug = True
    loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
    assert loader.transformations is None
    untransformed = loader.dataset
    for flag in (
        untransformed.normalise,
        untransformed.shiftmin,
        untransformed.gaussianblur,
    ):
        assert not flag
    loader.transformations = TRANSFORM_ALL_RESCALE
    loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
    transformed = loader.dataset
    for flag in (
        transformed.normalise,
        transformed.shiftmin,
        transformed.gaussianblur,
    ):
        assert flag
    assert len(transformed) == len(untransformed)
    raw_image, raw_label = next(iter(untransformed))
    proc_image, proc_label = next(iter(transformed))
    # Same sample (no shuffle), same label — but pixel values must differ.
    assert raw_label == proc_label
    assert not torch.equal(raw_image, proc_image)
Binary file added tests/testdata_npy/2_0_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/2_1_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/2_2_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/2_3_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/2_4_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/2_5_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/2_6_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/2_8_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/2_9_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/5_0_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/5_10_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/5_1_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/5_2_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/5_3_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/5_4_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/5_5_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/5_6_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/5_7_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/5_9_img.npy
Binary file not shown.
Empty file added tests/testdata_npy/__init__.py
Empty file.
Binary file added tests/testdata_npy/a_0_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/a_10_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/a_1_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/a_2_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/a_3_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/a_4_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/a_5_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/a_6_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/a_7_img.npy
Binary file not shown.
Binary file added tests/testdata_npy/a_8_img.npy
Binary file not shown.
Loading
Loading