Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding MapDataset and MapDataloader to Caked #35

Closed
wants to merge 27 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
c64fbff
Updates to gitignore
hllelli2 Jun 19, 2024
2f4d58f
basic setup config for pip install
hllelli2 Jun 19, 2024
6afd99e
Added MapDataset and MapDataLoader
hllelli2 Jun 19, 2024
ef73dd2
Added custom augments and transfroms as well as tiling
hllelli2 Jun 19, 2024
2ab4260
Added tests for new code
hllelli2 Jun 19, 2024
8b9d06d
feat: Add HDF5DataStore class for handling hdf5 files
hllelli2 Jul 2, 2024
d8729a1
Add pytest configuration file and conftest.py for test fixtures
hllelli2 Jul 4, 2024
1a6ed24
Added a dataset config class to have all the defaults in one place
hllelli2 Jul 9, 2024
c8486b8
Create a class for lazy-loading a hdf5 file
hllelli2 Jul 9, 2024
3779ee4
Moved multi-processing to utils, Added an array dataset which can loa…
hllelli2 Jul 9, 2024
b9a91cb
Moved Multi-processing code here, code here to duplciate arrays from …
hllelli2 Jul 9, 2024
ab2ce2f
Changed augment to take arrays instead of map-objects to make it comp…
hllelli2 Jul 9, 2024
d6eda3a
Removed mapObject transform base
hllelli2 Jul 9, 2024
a86c910
chore: Refactor transforms module and update transform classes
hllelli2 Jul 9, 2024
41d40b3
first attempt at reducing if label/weight... is not none repeated code
hllelli2 Jul 9, 2024
6f73aae
chore: Add test fixtures and update conftest.py for test setup
hllelli2 Jul 9, 2024
07e2318
set_gpu is not in the current ccpem-utils so check added JIC
hllelli2 Jul 10, 2024
a6cc9b1
chore: Refactor HDF5DataStore class and add support for temporary dir…
hllelli2 Jul 17, 2024
05f792b
Removed "_" id logic in MapDataLoader, might reimplement later but no…
hllelli2 Jul 17, 2024
5a770f7
Refactor process_datasets function and remove unnecessary code
hllelli2 Jul 17, 2024
0a99fd7
Refactor transforms module and update transform classes
hllelli2 Jul 17, 2024
8772fa6
re-added __del__ method to HDF5Store
hllelli2 Jul 17, 2024
93c8424
Refactor MapDataset and ArrayDataset classes to handle weight tensors
hllelli2 Sep 4, 2024
9128e9a
Refactor
hllelli2 Sep 4, 2024
b2a49c1
Refactor DecomposeToSlices, MapObjectMaskCrop, and MapObjectPadding c…
hllelli2 Sep 4, 2024
ce47030
Refactor test_map_io.py to handle weight tensors and update dataset l…
hllelli2 Sep 4, 2024
6ff4461
Merge branch 'main' into CCPEM-AddingMapDataset
hllelli2 Oct 1, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,7 @@ Thumbs.db
# Common editor files
*~
*.swp


# IDE specific files
.vscode/
36 changes: 36 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Setup configuration for the package
[metadata]
name = caked


# Options for the package

[options]

packages = find:
python_requires = >=3.8





[options.packages.find]
where = src
exclude =
tests
.github
.gitignore
.gitattributes
.pytest_cache
.git
.vscode
.history
*.egg
*.egg-info
docs
site
mkdocs.yml
*.ipynb
.mypy_cache
.ruff_cache

3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from setuptools import setup

setup()
122 changes: 122 additions & 0 deletions src/caked/Transforms/augments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from __future__ import annotations

import random
from enum import Enum

import numpy as np
from ccpem_utils.map.array_utils import rotate_array
from ccpem_utils.map.parse_mrcmapobj import MapObjHandle

from .base import AugmentBase


class Augments(Enum):
""" """

RANDOMROT = "randrot"
ROT90 = "rot90"


def get_augment(augment: str, random_seed) -> AugmentBase:
""" """

if augment == Augments.RANDOMROT.value:
return RandomRotationAugment(random_seed=random_seed)
if augment == Augments.ROT90.value:
return Rotation90Augment(random_seed=random_seed)

msg = f"Unknown Augmentation: {augment}, please choose from {Augments.__members__}"
raise ValueError(msg)


class ComposeAugment:
"""
Compose multiple Augments together.

:param augments: (list) list of augments to compose

:return: (np.ndarrry) transformed array
"""

def __init__(self, augments: list[str], random_seed: int = 42):
self.random_seed = random_seed
self.augments = augments

def __call__(self, data: np.ndarray, **kwargs) -> MapObjHandle:
for augment in self.augments:
data, augment_kwargs = get_augment(augment, random_seed=self.random_seed)(
data, **kwargs
)

kwargs.update(augment_kwargs)

return data, kwargs


class RandomRotationAugment(AugmentBase):
"""
Random or controlled rotation (if ax and an kwargs provided).

:param data: (np.ndarray) 3d volume
:param return_all: (bool) if True, will parameters of the rotation (ax, an)
:param interp: (bool) if True, will interpolate the rotation
:param ax: (int) 0 for yaw, 1 for pitch, 2 for roll
:param an: (int) number of times to rotate, between <1 and 3>

:return: (np.ndarray) rotated volume or (np.ndarray, int, int) rotated volume and rotation parameters
"""

def __init__(self, random_seed: int = 42):
super().__init__(random_seed)

def __call__(
self,
data: np.ndarray,
**kwargs,
) -> np.ndarray | tuple[np.ndarray, int, int]:
ax = kwargs.get("ax", None)
an = kwargs.get("an", None)
interp = kwargs.get("interp", True)

if (ax is not None and an is None) or (ax is None and an is not None):
msg = "When specifying rotation, please use both arguments to specify the axis and angle."
raise RuntimeError(msg)
rotations = [(0, 1), (0, 2), (1, 2)] # yaw, pitch, roll
if ax is None and an is None:
axes = random.randint(0, 2)
set_angles = [30, 60, 90]
angler = random.randint(0, 2)
angle = set_angles[angler]
else:
axes = ax
angle = an

r = rotations[axes]
data = rotate_array(data, angle, axes=r, interpolate=interp, reshape=False)

return data, {"ax": axes, "an": angle}


class Rotation90Augment(AugmentBase):
"""
Rotate the volume by 90 degrees.

:param data: (np.ndarray) 3d volume
:param return_all: (bool) if True, will parameters of the rotation (ax, an)
:param interp: (bool) if True, will interpolate the rotation
:param ax: (int) 0 for yaw, 1 for pitch, 2 for roll
:param an: (int) number of times to rotate, between <1 and 3>

:return: (np.ndarray) rotated volume or (np.ndarray, int, int) rotated volume and rotation parameters
"""

def __init__(self, random_seed: int = 42):
super().__init__(random_seed)

def __call__(
self,
data: np.ndarray,
**kwargs,
) -> np.ndarray:
msg = "Rotation90Augment not implemented yet."
raise NotImplementedError(msg)
39 changes: 39 additions & 0 deletions src/caked/Transforms/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from __future__ import annotations

from abc import ABC, abstractmethod

import numpy as np
from ccpem_utils.map.parse_mrcmapobj import MapObjHandle


class TransformBase(ABC):
"""
Base class for transformations.

"""

@abstractmethod
def __init__(self):
pass

@abstractmethod
def __call__(self, data):
msg = "The __call__ method must be implemented in the subclass"
raise NotImplementedError(msg)


class AugmentBase(ABC):
"""
Base class for augmentations.
"""

# This will need to take the hyper parameters for the augmentations

@abstractmethod
def __init__(self, random_seed: int = 42):
self.random_state = np.random.RandomState(random_seed)

@abstractmethod
def __call__(self, data, **kwargs):
msg = "The __call__ method must be implemented in the subclass"
raise NotImplementedError(msg)
Loading
Loading