
Connected processing to loading the dataset. #11

Merged: 6 commits, Feb 22, 2024
Changes from 1 commit
2 changes: 1 addition & 1 deletion src/caked/base.py
@@ -25,7 +25,7 @@ def load(self, datapath, datatype):
         pass
 
     @abstractmethod
-    def process(self):
+    def process(self, paths, datatype):
         pass
 
     @abstractmethod
29 changes: 25 additions & 4 deletions src/caked/dataloader.py
@@ -15,6 +15,9 @@
 
 from .base import AbstractDataLoader, AbstractDataset
 
+np.random.seed(42)
+TRANSFORM_OPTIONS = ["rescale", "normalise", "gaussianblur", "shiftmin"]
+
 
 class DiskDataLoader(AbstractDataLoader):
     def __init__(
@@ -24,11 +27,13 @@ def __init__(
         training: bool = True,
         classes: list[str] | None = None,
         pipeline: str = "disk",
+        transformations: str | None = None,
     ) -> None:
         self.dataset_size = dataset_size
         self.save_to_disk = save_to_disk
         self.training = training
         self.pipeline = pipeline
+        self.transformations = transformations
         if classes is None:
             self.classes = []
         else:
@@ -69,10 +74,27 @@ def load(self, datapath, datatype) -> None:
         if self.dataset_size is not None:
             paths = paths[: self.dataset_size]
 
-        self.dataset = DiskDataset(paths=paths, datatype=datatype)
+        if self.transformations is None:
+            self.dataset = DiskDataset(paths=paths, datatype=datatype)
+        else:
+            self.dataset = self.process(paths=paths, datatype=datatype)
 
-    def process(self):
-        return super().process()
+    def process(self, paths: list[str], datatype: str):
+        if self.transformations is None:
+            msg = "No processing to do as no transformations were provided."
+            raise RuntimeError(msg)
+        transforms = self.transformations.split(",")
+        rescale, normalise, gaussianblur, shiftmin = np.in1d(
+            TRANSFORM_OPTIONS, transforms
+        )
+        return DiskDataset(
+            paths=paths,
+            datatype=datatype,
+            rescale=rescale,
+            normalise=normalise,
+            gaussianblur=gaussianblur,
+            shiftmin=shiftmin,
+        )
 
     def get_loader(self, batch_size: int, split_size: float | None = None):
         if self.training:
@@ -130,7 +152,6 @@ def __init__(
         self.rescale = rescale
         self.normalise = normalise
         self.gaussianblur = gaussianblur
-        self.rescale = rescale
         self.transform = input_transform
         self.datatype = datatype
         self.shiftmin = shiftmin
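A note on the flag unpacking in `process`: `np.in1d` tests each element of its first argument for membership in its second, so the booleans come back ordered like `TRANSFORM_OPTIONS`, not like the user's string. A minimal standalone sketch of that mapping (the sample string mirrors the `TRANSFORM_SOME` constant from this PR's tests):

```python
import numpy as np

TRANSFORM_OPTIONS = ["rescale", "normalise", "gaussianblur", "shiftmin"]

# Sample input mirroring the PR's TRANSFORM_SOME test constant.
transforms = "rescale,gaussianblur".split(",")

# The result is aligned with TRANSFORM_OPTIONS, so it can be unpacked
# directly into the four flags in that fixed order.
rescale, normalise, gaussianblur, shiftmin = np.in1d(TRANSFORM_OPTIONS, transforms)

print(rescale, normalise, gaussianblur, shiftmin)  # True False True False
```

One consequence of this membership test: a misspelled option in the string is silently ignored rather than raising an error, since it simply matches nothing in `TRANSFORM_OPTIONS`.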
32 changes: 32 additions & 0 deletions tests/test_disk_io.py
@@ -18,6 +18,8 @@
 DISK_CLASSES_MISSING = ["2b3a", "1b23"]
 DISK_CLASSES_NONE = None
 DATATYPE_MRC = "mrc"
+TRANSFORM_ALL = "rescale,normalise,gaussianblur,shiftmin"
+TRANSFORM_SOME = "rescale,gaussianblur"
 
 
 def test_class_instantiation():
@@ -127,3 +129,33 @@ def test_get_loader_training_fail():
         torch_loader_train, torch_loader_val = test_loader.get_loader(
             split_size=1, batch_size=64
         )

Collaborator:
Also, it would be good to test that the data loader returns what we want. So far we have only tested that the class variables are correct (which is great!), but we won't catch things like the loader returning the wrong amount of data or selecting the wrong labels.

Collaborator Author:

That's a solid point! I had checked it, but realised it was only in a local function; I'll turn that into assertions here.
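A sketch of the kind of assertion being suggested, assuming the torch loader yields `(data, label)` batches and that `get_loader` returns a single loader when `training=False`; the batch size and batch structure here are assumptions, not part of this PR:

```python
def test_loader_returns_expected_batches():
    # Hypothetical sketch reusing this PR's test constants; the (data, label)
    # batch structure and single-loader return value are assumptions.
    test_loader = DiskDataLoader(
        pipeline=DISK_PIPELINE,
        classes=DISK_CLASSES_FULL,
        dataset_size=DATASET_SIZE_ALL,
        training=False,
    )
    test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
    torch_loader = test_loader.get_loader(batch_size=4)
    data, labels = next(iter(torch_loader))
    # The loader should hand back the requested amount of data per batch...
    assert data.shape[0] == 4
    # ...and only labels drawn from the declared classes.
    assert all(label in DISK_CLASSES_FULL for label in labels)
```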


+def test_processing_data_all_transforms():
+    test_loader = DiskDataLoader(
+        pipeline=DISK_PIPELINE,
+        classes=DISK_CLASSES_FULL,
+        dataset_size=DATASET_SIZE_ALL,
+        training=True,
+        transformations=TRANSFORM_ALL,
+    )
+    test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
+    assert test_loader.dataset.normalise
+    assert test_loader.dataset.shiftmin
+    assert test_loader.dataset.gaussianblur
+    assert test_loader.dataset.rescale
+
+
+def test_processing_data_some_transforms():
+    test_loader = DiskDataLoader(
+        pipeline=DISK_PIPELINE,
+        classes=DISK_CLASSES_FULL,
+        dataset_size=DATASET_SIZE_ALL,
+        training=True,
+        transformations=TRANSFORM_SOME,
+    )
+    test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
+    assert not test_loader.dataset.normalise
Collaborator:

We could add a test to make sure the transformation is actually happening, e.g. that the output dataset (for an example data point) differs between passing a transformation and passing none.

Collaborator Author:

The latest commit addresses this! I hope that's OK: I added a debug flag that doesn't shuffle the paths, so we can test that the processing is happening (a sketch of such a comparison follows the diff below).

+    assert not test_loader.dataset.shiftmin
+    assert test_loader.dataset.gaussianblur
+    assert test_loader.dataset.rescale
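A sketch of the comparison test discussed in the thread above: build one loader without transformations and one with, then check that an example data point differs. The indexing into `dataset` and the `(image, label)` item structure are assumptions; the real test would also use the debug flag mentioned in the reply (its name isn't shown in this diff) so both loaders see the paths in the same order:

```python
import numpy as np


def test_transformation_changes_output():
    # Hypothetical sketch reusing this PR's test constants.
    plain_loader = DiskDataLoader(
        pipeline=DISK_PIPELINE,
        classes=DISK_CLASSES_FULL,
        dataset_size=DATASET_SIZE_ALL,
        training=True,
    )
    transformed_loader = DiskDataLoader(
        pipeline=DISK_PIPELINE,
        classes=DISK_CLASSES_FULL,
        dataset_size=DATASET_SIZE_ALL,
        training=True,
        transformations=TRANSFORM_ALL,
    )
    plain_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
    transformed_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
    # Compare the first item from each dataset; assumes __getitem__ yields
    # an (image, label) pair with an array-like image.
    plain_image, _ = plain_loader.dataset[0]
    transformed_image, _ = transformed_loader.dataset[0]
    assert not np.allclose(np.asarray(plain_image), np.asarray(transformed_image))
```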