|
| 1 | +# __init__.py |
| 2 | +# This module defines custom transformations for image convolution. |
| 3 | +# Author: Brian Recktenwall-Calvet |
| 4 | +# Date: 12-18-2024 |
| 5 | +# Version: 1.0 |
| 6 | + |
1 | 7 | import os
|
2 | 8 | import pathlib
|
3 | 9 | import mosaique.models.kernels
|
4 | 10 | import numpy as np
|
5 |
| -from typing import Any, Callable |
| 11 | +from typing import Any, List, Optional |
| 12 | + |
6 | 13 |
|
7 | 14 | class ConvolutionLayer4x4:
|
| 15 | + """ |
| 16 | + A class representing a 4x4 convolution layer using a specified kernel. |
| 17 | +
|
| 18 | + Attributes: |
| 19 | + _name (str): The name of the convolution layer. |
| 20 | + _kernel (kernels.Kernel2d4x4): An instance of the kernel used for convolution operations. |
| 21 | +
|
| 22 | + Methods: |
| 23 | + fit(dataset): Fits the kernel to the provided dataset. |
| 24 | + transform(dataset): Applies the convolution transformation to the dataset. |
| 25 | + post_transform(dataset): Applies post-processing to the dataset after transformation. |
| 26 | + channel_merge(dataset): Merges channels in the dataset. |
| 27 | + save(dataset, variant): Saves a portion of the dataset to a .npy file based on the variant. |
| 28 | + open(variant): Loads a saved dataset from a .npy file based on the variant. |
| 29 | + """ |
| 30 | + |
8 | 31 | _name: str
|
9 | 32 | _kernel: kernels.Kernel2d4x4
|
10 | 33 |
|
11 |
| - |
12 | 34 | @property
|
13 |
| - def name(self) -> str : |
| 35 | + def name(self) -> str: |
| 36 | + """Gets the name of the convolution layer.""" |
14 | 37 | return self._name
|
| 38 | + |
15 | 39 | @name.setter
|
16 | 40 | def name(self, value: str):
|
| 41 | + """Sets the name of the convolution layer.""" |
17 | 42 | self._name = value
|
| 43 | + |
18 | 44 | @property
|
19 | 45 | def kernel(self) -> kernels.Kernel2d4x4:
|
| 46 | + """Gets the kernel used for convolution operations.""" |
20 | 47 | return self._kernel
|
| 48 | + |
21 | 49 | @kernel.setter
|
22 | 50 | def kernel(self, value: kernels.Kernel2d4x4):
|
| 51 | + """Sets the kernel used for convolution operations.""" |
23 | 52 | self._kernel = value
|
24 | 53 |
|
25 |
| - def __init__(self, name: str, kernel_shape:[int]=None): |
| 54 | + def __init__(self, name: str, kernel_shape: Optional[List[int]] = None): |
| 55 | + """ |
| 56 | + Initializes the ConvolutionLayer4x4 with a given name and kernel shape. |
| 57 | +
|
| 58 | + Args: |
| 59 | + name (str): The name of the convolution layer. |
| 60 | + kernel_shape (Optional[List[int]]): The shape of the kernel (default is [2, 2]). |
| 61 | +
|
| 62 | + Raises: |
| 63 | + ValueError: If kernel_shape is not a valid shape for the kernel. |
| 64 | + """ |
26 | 65 | if kernel_shape is None:
|
27 | 66 | kernel_shape = [2, 2]
|
28 | 67 | self.name = name
|
29 | 68 | self.kernel = kernels.Kernel2d4x4(kernel_shape)
|
30 | 69 |
|
31 | 70 | def fit(self, dataset: np.ndarray[..., np.dtype[Any]]):
|
| 71 | + """ |
| 72 | + Fits the kernel to the provided dataset. |
| 73 | +
|
| 74 | + Args: |
| 75 | + dataset (np.ndarray): The input dataset used to fit the kernel. |
| 76 | + """ |
32 | 77 | self.kernel.fit(dataset)
|
33 | 78 |
|
34 |
| - def transform(self, dataset: np.ndarray[..., np.dtype[Any]]): |
| 79 | + def transform(self, dataset: np.ndarray[..., np.dtype[Any]]) -> np.ndarray: |
| 80 | + """ |
| 81 | + Applies the convolution transformation to the dataset. |
| 82 | +
|
| 83 | + Args: |
| 84 | + dataset (np.ndarray): The input dataset to transform. |
| 85 | +
|
| 86 | + Returns: |
| 87 | + np.ndarray: The transformed dataset after applying the convolution. |
| 88 | + """ |
35 | 89 | return self.kernel.transform(dataset)
|
36 | 90 |
|
37 |
| - def post_transform(self, dataset: np.ndarray[..., np.dtype[Any]]): |
| 91 | + def post_transform(self, dataset: np.ndarray[..., np.dtype[Any]]) -> np.ndarray: |
| 92 | + """ |
| 93 | + Applies post-processing to the dataset after transformation. |
| 94 | +
|
| 95 | + Args: |
| 96 | + dataset (np.ndarray): The input dataset to post-process. |
| 97 | +
|
| 98 | + Returns: |
| 99 | + np.ndarray: The post-processed dataset. |
| 100 | + """ |
38 | 101 | return self.kernel.post_transform(dataset)
|
39 |
| - def channel_merge(self, dataset: np.ndarray[..., np.dtype[Any]]): |
| 102 | + |
| 103 | + def channel_merge(self, dataset: np.ndarray[..., np.dtype[Any]]) -> np.ndarray: |
| 104 | + """ |
| 105 | + Merges channels in the dataset, typically for multi-channel inputs. |
| 106 | +
|
| 107 | + Args: |
| 108 | + dataset (np.ndarray): The input dataset whose channels will be merged. |
| 109 | +
|
| 110 | + Returns: |
| 111 | + np.ndarray: The dataset with merged channels. |
| 112 | + """ |
40 | 113 | return self.kernel.channel_merge(dataset)
|
41 | 114 |
|
42 |
| - def save(self, dataset: np.ndarray[..., np.dtype[Any]], variant:[int]): |
43 |
| - variant_string = ''.join(map(str,variant)) |
| 115 | + def save(self, dataset: np.ndarray[..., np.dtype[Any]], variant: List[int]): |
| 116 | + """ |
| 117 | + Saves a portion of the dataset to a .npy file based on the variant. |
| 118 | +
|
| 119 | + Args: |
| 120 | + dataset (np.ndarray): The dataset to save. |
| 121 | + variant (List[int]): A list of indices specifying which channels to save. |
| 122 | + """ |
| 123 | + variant_string = ''.join(map(str, variant)) |
44 | 124 | workdir = str(pathlib.Path().resolve()) + "/" + self.name
|
45 | 125 | os.makedirs(workdir, exist_ok=True)
|
46 |
| - np.save(os.path.join(workdir, variant_string), dataset[:,:,:,variant]) |
| 126 | + np.save(os.path.join(workdir, variant_string), dataset[:, :, :, variant]) |
| 127 | + |
| 128 | + def open(self, variant: List[int]) -> np.ndarray: |
| 129 | + """ |
| 130 | + Loads a saved dataset from a .npy file based on the variant. |
47 | 131 |
|
48 |
| - def open(self, variant:[int]): |
49 |
| - variant_string = ''.join(map(str,variant)) |
| 132 | + Args: |
| 133 | + variant (List[int]): A list of indices specifying which channels to load. |
| 134 | +
|
| 135 | + Returns: |
| 136 | + np.ndarray: The loaded dataset. |
| 137 | +
|
| 138 | + Raises: |
| 139 | + FileNotFoundError: If the specified file does not exist. |
| 140 | + """ |
| 141 | + variant_string = ''.join(map(str, variant)) |
50 | 142 | workdir = str(pathlib.Path().resolve()) + "/" + self.name
|
51 | 143 | return np.load(os.path.join(workdir, variant_string + '.npy'))
|
52 |
| - |
|
0 commit comments