Skip to content

Commit

Permalink
added failing test. loader is no longer a class variable.
Browse files Browse the repository at this point in the history
  • Loading branch information
mooniean committed Feb 12, 2024
1 parent 48bfe1d commit 8ff5024
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
9 changes: 4 additions & 5 deletions src/caked/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,27 +93,26 @@ def get_loader(self, batch_size: int, split_size: float | None = None):
train_data = Subset(self.dataset, indices=idx[:-s])
val_data = Subset(self.dataset, indices=idx[-s:])

self.loader_train = DataLoader(
loader_train = DataLoader(
train_data,
batch_size=batch_size,
num_workers=0,
shuffle=True,
)
self.loader_val = DataLoader(
loader_val = DataLoader(
val_data,
batch_size=batch_size,
num_workers=0,
shuffle=True,
)
return self.loader_val, self.loader_train
return loader_val, loader_train

self.loader = DataLoader(
return DataLoader(
self.dataset,
batch_size=batch_size,
num_workers=0,
shuffle=True,
)
return self.loader


class DiskDataset(AbstractDataset):
Expand Down
14 changes: 14 additions & 0 deletions tests/test_disk_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,17 @@ def test_get_loader_training_true():
)
assert isinstance(torch_loader_train, torch.utils.data.DataLoader)
assert isinstance(torch_loader_val, torch.utils.data.DataLoader)


def test_get_loader_training_fail():
test_loader = DiskDataLoader(
pipeline=DISK_PIPELINE,
classes=DISK_CLASSES_FULL,
dataset_size=DATASET_SIZE_ALL,
training=True,
)
test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)
with pytest.raises(Exception, match=r".* sets must be larger than .*"):
torch_loader_train, torch_loader_val = test_loader.get_loader(
split_size=1, batch_size=64
)

0 comments on commit 8ff5024

Please sign in to comment.