
Commit f45e418

Merge pull request #7 from alan-turing-institute/dataloader_builder
Adding get_loader function
mooniean authored Feb 12, 2024
2 parents bd47c25 + 8ff5024 commit f45e418
Showing 3 changed files with 86 additions and 3 deletions.
6 changes: 5 additions & 1 deletion src/caked/base.py
@@ -29,7 +29,11 @@ def process(self):
         pass

     @abstractmethod
-    def get_loader(self, split_size: float, batch_size: int):
+    def get_loader(
+        self,
+        batch_size: int,
+        split_size: float | None = None,
+    ):
         pass


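For callers, the reordered signature means batch_size is now the only required argument and split_size moves to an optional keyword. A quick before/after sketch of the call sites (the loader instance and the 0.8/64 values are illustrative, not part of the commit):

    # before: split_size came first and was required
    loader.get_loader(0.8, 64)
    # after: batch_size alone, or batch_size plus an optional split
    loader.get_loader(batch_size=64)
    loader.get_loader(batch_size=64, split_size=0.8)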
42 changes: 40 additions & 2 deletions src/caked/dataloader.py
@@ -10,6 +10,7 @@
 import numpy as np
 import torch
 from scipy.ndimage import zoom
+from torch.utils.data import DataLoader, Subset
 from torchvision import transforms

 from .base import AbstractDataLoader, AbstractDataset
@@ -73,8 +74,45 @@ def load(self, datapath, datatype) -> None:
     def process(self):
         return super().process()

-    def get_loader(self, split_size: float, batch_size: int):
-        return super().get_loader(split_size, batch_size)
+    def get_loader(self, batch_size: int, split_size: float | None = None):
+        if self.training:
+            if split_size is None:
+                msg = "Split size must be provided for training."
+                raise RuntimeError(msg)
+            # shuffle the indices, then split into train / val sets
+            idx = np.random.permutation(len(self.dataset))
+            # accept split_size as a fraction (e.g. 0.8) or a percentage (e.g. 80)
+            if split_size < 1:
+                split_size = split_size * 100
+
+            # s is the number of training samples
+            s = int(np.ceil(len(self.dataset) * int(split_size) / 100))
+            if s < 2:
+                msg = (
+                    "Train and validation sets must be larger than 1 sample, "
+                    f"train: {len(idx[-s:])}, val: {len(idx[:-s])}."
+                )
+                raise RuntimeError(msg)
+            train_data = Subset(self.dataset, indices=idx[-s:])
+            val_data = Subset(self.dataset, indices=idx[:-s])
+
+            loader_train = DataLoader(
+                train_data,
+                batch_size=batch_size,
+                num_workers=0,
+                shuffle=True,
+            )
+            loader_val = DataLoader(
+                val_data,
+                batch_size=batch_size,
+                num_workers=0,
+                shuffle=True,
+            )
+            return loader_train, loader_val
+
+        # not training: a single loader over the full dataset
+        return DataLoader(
+            self.dataset,
+            batch_size=batch_size,
+            num_workers=0,
+            shuffle=True,
+        )


 class DiskDataset(AbstractDataset):
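Read together with the tests below, the intended usage is roughly the following; the constructor and load arguments reuse the constants from the test module, and the batch/split values are illustrative:

    loader = DiskDataLoader(
        pipeline=DISK_PIPELINE,
        classes=DISK_CLASSES_FULL,
        dataset_size=DATASET_SIZE_ALL,
        training=True,
    )
    loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
    # training=True plus a split_size yields a (train, val) pair of DataLoaders
    train_loader, val_loader = loader.get_loader(batch_size=64, split_size=0.8)
    # with training=False, the same call returns a single DataLoader instead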
41 changes: 41 additions & 0 deletions tests/test_io.py → tests/test_disk_io.py
@@ -86,3 +86,44 @@ def test_one_image():
     test_item_image, test_item_name = test_dataset.__getitem__(1)
     assert test_item_name in DISK_CLASSES_FULL
     assert isinstance(test_item_image, torch.Tensor)
+
+
+def test_get_loader_training_false():
+    test_loader = DiskDataLoader(
+        pipeline=DISK_PIPELINE,
+        classes=DISK_CLASSES_FULL,
+        dataset_size=DATASET_SIZE_ALL,
+        training=False,
+    )
+    test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
+    torch_loader = test_loader.get_loader(batch_size=64)
+    assert isinstance(torch_loader, torch.utils.data.DataLoader)
+
+
+def test_get_loader_training_true():
+    test_loader = DiskDataLoader(
+        pipeline=DISK_PIPELINE,
+        classes=DISK_CLASSES_FULL,
+        dataset_size=DATASET_SIZE_ALL,
+        training=True,
+    )
+    test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
+    torch_loader_train, torch_loader_val = test_loader.get_loader(
+        split_size=0.8, batch_size=64
+    )
+    assert isinstance(torch_loader_train, torch.utils.data.DataLoader)
+    assert isinstance(torch_loader_val, torch.utils.data.DataLoader)
+
+
+def test_get_loader_training_fail():
+    test_loader = DiskDataLoader(
+        pipeline=DISK_PIPELINE,
+        classes=DISK_CLASSES_FULL,
+        dataset_size=DATASET_SIZE_ALL,
+        training=True,
+    )
+    test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
+    with pytest.raises(RuntimeError, match=r".* sets must be larger than .*"):
+        torch_loader_train, torch_loader_val = test_loader.get_loader(
+            split_size=1, batch_size=64
+        )
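A note on why split_size=1 is expected to fail: 1 is not strictly less than 1, so get_loader treats it as 1 percent rather than a 1.0 fraction, and the train split rounds to a single sample. A toy check (the dataset length of 100 is hypothetical):

    import numpy as np

    split_size = 1  # interpreted as 1%, not as a 100% fraction
    n = 100  # hypothetical dataset length
    s = int(np.ceil(n * int(split_size) / 100))
    print(s)  # 1, which fails the "s < 2" guard and raises RuntimeError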
