diff --git a/src/caked/base.py b/src/caked/base.py
index 767bc8a..646ac00 100644
--- a/src/caked/base.py
+++ b/src/caked/base.py
@@ -29,7 +29,11 @@ def process(self):
         pass
 
     @abstractmethod
-    def get_loader(self, split_size: float, batch_size: int):
+    def get_loader(
+        self,
+        batch_size: int,
+        split_size: float | None = None,
+    ):
         pass
 
diff --git a/src/caked/dataloader.py b/src/caked/dataloader.py
index f66029e..cbdd1a0 100644
--- a/src/caked/dataloader.py
+++ b/src/caked/dataloader.py
@@ -10,6 +10,7 @@
 import numpy as np
 import torch
 from scipy.ndimage import zoom
+from torch.utils.data import DataLoader, Subset
 from torchvision import transforms
 
 from .base import AbstractDataLoader, AbstractDataset
@@ -73,8 +74,45 @@ def load(self, datapath, datatype) -> None:
     def process(self):
         return super().process()
 
-    def get_loader(self, split_size: float, batch_size: int):
-        return super().get_loader(split_size, batch_size)
+    def get_loader(self, batch_size: int, split_size: float | None = None):
+        if self.training:
+            if split_size is None:
+                msg = "Split size must be provided for training."
+                raise RuntimeError(msg)
+            # split into train / val sets
+            idx = np.random.permutation(len(self.dataset))
+            if split_size < 1:
+                split_size = split_size * 100
+
+            s = int(np.ceil(len(self.dataset) * int(split_size) / 100))
+            if s < 2:
+                msg = "Train and validation sets must be larger than 1 sample, train: {}, val: {}.".format(
+                    len(idx[:-s]), len(idx[-s:])
+                )
+                raise RuntimeError(msg)
+            train_data = Subset(self.dataset, indices=idx[:-s])
+            val_data = Subset(self.dataset, indices=idx[-s:])
+
+            loader_train = DataLoader(
+                train_data,
+                batch_size=batch_size,
+                num_workers=0,
+                shuffle=True,
+            )
+            loader_val = DataLoader(
+                val_data,
+                batch_size=batch_size,
+                num_workers=0,
+                shuffle=True,
+            )
+            return loader_train, loader_val
+
+        return DataLoader(
+            self.dataset,
+            batch_size=batch_size,
+            num_workers=0,
+            shuffle=True,
+        )
 
 
 class DiskDataset(AbstractDataset):
diff --git a/tests/test_io.py b/tests/test_disk_io.py
similarity index 68%
rename from tests/test_io.py
rename to tests/test_disk_io.py
index 355ef3e..d7df678 100644
--- a/tests/test_io.py
+++ b/tests/test_disk_io.py
@@ -86,3 +86,44 @@ def test_one_image():
     test_item_image, test_item_name = test_dataset.__getitem__(1)
     assert test_item_name in DISK_CLASSES_FULL
     assert isinstance(test_item_image, torch.Tensor)
+
+
+def test_get_loader_training_false():
+    test_loader = DiskDataLoader(
+        pipeline=DISK_PIPELINE,
+        classes=DISK_CLASSES_FULL,
+        dataset_size=DATASET_SIZE_ALL,
+        training=False,
+    )
+    test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
+    torch_loader = test_loader.get_loader(batch_size=64)
+    assert isinstance(torch_loader, torch.utils.data.DataLoader)
+
+
+def test_get_loader_training_true():
+    test_loader = DiskDataLoader(
+        pipeline=DISK_PIPELINE,
+        classes=DISK_CLASSES_FULL,
+        dataset_size=DATASET_SIZE_ALL,
+        training=True,
+    )
+    test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
+    torch_loader_train, torch_loader_val = test_loader.get_loader(
+        split_size=0.8, batch_size=64
+    )
+    assert isinstance(torch_loader_train, torch.utils.data.DataLoader)
+    assert isinstance(torch_loader_val, torch.utils.data.DataLoader)
+
+
+def test_get_loader_training_fail():
+    test_loader = DiskDataLoader(
+        pipeline=DISK_PIPELINE,
+        classes=DISK_CLASSES_FULL,
+        dataset_size=DATASET_SIZE_ALL,
+        training=True,
+    )
+    test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
+    with pytest.raises(RuntimeError, match=r".* sets must be larger than .*"):
+        torch_loader_train, torch_loader_val = test_loader.get_loader(
+            split_size=1, batch_size=64
+        )