From 48bfe1d68c2e7580a7a477fd67917bbeab9fa4cd Mon Sep 17 00:00:00 2001 From: mooniean Date: Fri, 9 Feb 2024 10:55:43 +0000 Subject: [PATCH 1/2] added dataloader functionality, get_loader function, changed the name of the test file --- src/caked/base.py | 6 +++- src/caked/dataloader.py | 43 +++++++++++++++++++++++++-- tests/{test_io.py => test_disk_io.py} | 27 +++++++++++++++++ 3 files changed, 73 insertions(+), 3 deletions(-) rename tests/{test_io.py => test_disk_io.py} (76%) diff --git a/src/caked/base.py b/src/caked/base.py index 767bc8a..646ac00 100644 --- a/src/caked/base.py +++ b/src/caked/base.py @@ -29,7 +29,11 @@ def process(self): pass @abstractmethod - def get_loader(self, split_size: float, batch_size: int): + def get_loader( + self, + batch_size: int, + split_size: float | None = None, + ): pass diff --git a/src/caked/dataloader.py b/src/caked/dataloader.py index f66029e..94e4590 100644 --- a/src/caked/dataloader.py +++ b/src/caked/dataloader.py @@ -10,6 +10,7 @@ import numpy as np import torch from scipy.ndimage import zoom +from torch.utils.data import DataLoader, Subset from torchvision import transforms from .base import AbstractDataLoader, AbstractDataset @@ -73,8 +74,46 @@ def load(self, datapath, datatype) -> None: def process(self): return super().process() - def get_loader(self, split_size: float, batch_size: int): - return super().get_loader(split_size, batch_size) + def get_loader(self, batch_size: int, split_size: float | None = None): + if self.training: + if split_size is None: + msg = "Split size must be provided for training. 
" + raise RuntimeError(msg) + # split into train / val sets + idx = np.random.permutation(len(self.dataset)) + if split_size < 1: + split_size = split_size * 100 + + s = int(np.ceil(len(self.dataset) * int(split_size) / 100)) + if s < 2: + msg = "Train and validation sets must be larger than 1 sample, train: {}, val: {}.".format( + len(idx[:-s]), len(idx[-s:]) + ) + raise RuntimeError(msg) + train_data = Subset(self.dataset, indices=idx[:-s]) + val_data = Subset(self.dataset, indices=idx[-s:]) + + self.loader_train = DataLoader( + train_data, + batch_size=batch_size, + num_workers=0, + shuffle=True, + ) + self.loader_val = DataLoader( + val_data, + batch_size=batch_size, + num_workers=0, + shuffle=True, + ) + return self.loader_val, self.loader_train + + self.loader = DataLoader( + self.dataset, + batch_size=batch_size, + num_workers=0, + shuffle=True, + ) + return self.loader class DiskDataset(AbstractDataset): diff --git a/tests/test_io.py b/tests/test_disk_io.py similarity index 76% rename from tests/test_io.py rename to tests/test_disk_io.py index 355ef3e..9e467cf 100644 --- a/tests/test_io.py +++ b/tests/test_disk_io.py @@ -86,3 +86,30 @@ def test_one_image(): test_item_image, test_item_name = test_dataset.__getitem__(1) assert test_item_name in DISK_CLASSES_FULL assert isinstance(test_item_image, torch.Tensor) + + +def test_get_loader_training_false(): + test_loader = DiskDataLoader( + pipeline=DISK_PIPELINE, + classes=DISK_CLASSES_FULL, + dataset_size=DATASET_SIZE_ALL, + training=False, + ) + test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC) + torch_loader = test_loader.get_loader(batch_size=64) + assert isinstance(torch_loader, torch.utils.data.DataLoader) + + +def test_get_loader_training_true(): + test_loader = DiskDataLoader( + pipeline=DISK_PIPELINE, + classes=DISK_CLASSES_FULL, + dataset_size=DATASET_SIZE_ALL, + training=True, + ) + test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC) + torch_loader_train, torch_loader_val = 
test_loader.get_loader( + split_size=0.8, batch_size=64 + ) + assert isinstance(torch_loader_train, torch.utils.data.DataLoader) + assert isinstance(torch_loader_val, torch.utils.data.DataLoader) From 8ff50240cb7c82c2e78d08b53df9357917769834 Mon Sep 17 00:00:00 2001 From: mooniean Date: Mon, 12 Feb 2024 12:04:04 +0000 Subject: [PATCH 2/2] added failing test. loader is no longer a class variable. --- src/caked/dataloader.py | 9 ++++----- tests/test_disk_io.py | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/caked/dataloader.py b/src/caked/dataloader.py index 94e4590..cbdd1a0 100644 --- a/src/caked/dataloader.py +++ b/src/caked/dataloader.py @@ -93,27 +93,26 @@ def get_loader(self, batch_size: int, split_size: float | None = None): train_data = Subset(self.dataset, indices=idx[:-s]) val_data = Subset(self.dataset, indices=idx[-s:]) - self.loader_train = DataLoader( + loader_train = DataLoader( train_data, batch_size=batch_size, num_workers=0, shuffle=True, ) - self.loader_val = DataLoader( + loader_val = DataLoader( val_data, batch_size=batch_size, num_workers=0, shuffle=True, ) - return self.loader_val, self.loader_train + return loader_val, loader_train - self.loader = DataLoader( + return DataLoader( self.dataset, batch_size=batch_size, num_workers=0, shuffle=True, ) - return self.loader class DiskDataset(AbstractDataset): diff --git a/tests/test_disk_io.py b/tests/test_disk_io.py index 9e467cf..d7df678 100644 --- a/tests/test_disk_io.py +++ b/tests/test_disk_io.py @@ -113,3 +113,17 @@ def test_get_loader_training_true(): ) assert isinstance(torch_loader_train, torch.utils.data.DataLoader) assert isinstance(torch_loader_val, torch.utils.data.DataLoader) + + +def test_get_loader_training_fail(): + test_loader = DiskDataLoader( + pipeline=DISK_PIPELINE, + classes=DISK_CLASSES_FULL, + dataset_size=DATASET_SIZE_ALL, + training=True, + ) + test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC) + with 
pytest.raises(RuntimeError, match=r".* sets must be larger than .*"):
+        torch_loader_train, torch_loader_val = test_loader.get_loader(
+            split_size=1, batch_size=64
+        )