From 2b3322df0075f6e9d3b2a83b54f9c751b9c1db40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicola=20VIGAN=C3=92?= Date: Wed, 24 Apr 2024 13:31:16 +0200 Subject: [PATCH] GI: Add support for generating more masks than pixels and fixed PEP08 compliance MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Nicola VIGANĂ’ --- corrct/struct_illum.py | 162 +++++++++++++++++++++++------------------ 1 file changed, 92 insertions(+), 70 deletions(-) diff --git a/corrct/struct_illum.py b/corrct/struct_illum.py index 292b990..97d4e70 100644 --- a/corrct/struct_illum.py +++ b/corrct/struct_illum.py @@ -173,10 +173,10 @@ def __init__( if mask_support is not None: self.mask_support = np.array(mask_support, ndmin=1, dtype=int) else: - self.mask_support = np.array(self.shape_FoV, ndmin=1, dtype=int) + self.mask_support = np.array(self.shape_fov, ndmin=1, dtype=int) @property - def shape_FoV(self) -> Sequence[int]: + def shape_fov(self) -> Sequence[int]: """Return the mask shape. Returns @@ -208,6 +208,17 @@ def num_buckets(self) -> int: """ return int(np.prod(self.shape_shifts)) + @property + def num_pixels(self) -> int: + """Compute the number of pixels in the image. + + Returns + ------- + int + The number of pixels in the image. + """ + return int(np.prod(self.shape_fov)) + def info(self) -> str: """ Return the mask info. @@ -345,7 +356,7 @@ def inspect_masks(self, mask_inds_vu: Union[None, Sequence[int], NDArrayInt] = N class MaskGenerator(ABC): """Define mask generation interface.""" - shape_FoV: NDArrayInt + shape_fov: NDArrayInt shape_mask: NDArrayInt shape_shifts: NDArrayInt transmittance: float @@ -357,7 +368,7 @@ class MaskGenerator(ABC): def __init__( self, - shape_FoV: Union[Sequence[int], NDArrayInt], + shape_fov: Union[Sequence[int], NDArrayInt], shape_mask: Union[Sequence[int], NDArrayInt], shape_shifts: Union[Sequence[int], NDArrayInt], transmittance: float = 1.0, @@ -367,7 +378,7 @@ def __init__( Parameters ---------- - shape_FoV : Sequence[int] | NDArray[np.integer] + shape_fov : Sequence[int] | NDArray[np.integer] The shape of the field-of-view. shape_mask : Sequence[int] | NDArray[np.integer] The shape of the masks. @@ -378,7 +389,7 @@ def __init__( dtype : DTypeLike The dtype of the created masks. """ - self.shape_FoV = np.array(shape_FoV, dtype=int) + self.shape_fov = np.array(shape_fov, dtype=int) self.shape_mask = np.array(shape_mask, dtype=int) self.shape_shifts = np.array(shape_shifts, dtype=int) @@ -410,20 +421,31 @@ def __repr__(self) -> str: return self.__class__.__name__ + " {\n" + ",\n".join([f" {k} = {v}" for k, v in self.__dict__.items()]) + "\n}" @property - def num_buckets(self) -> int: - """Compute the number of buckets. + def max_buckets(self) -> int: + """Compute the maximum number of buckets. Returns ------- int - The number of buckets. + The maximum number of buckets. """ return int(np.prod(self.shape_shifts)) - def _init_FoV_mm(self, FoV_size_mm: Union[float, Sequence[float], NDArray], req_res_mm: float) -> NDArrayInt: - self.FoV_size_mm = np.array(FoV_size_mm, ndmin=1) - num_points = np.ceil(self.FoV_size_mm / req_res_mm).astype(int) - self.feature_size_mm = self.FoV_size_mm / num_points + @property + def num_pixels(self) -> int: + """Compute the number of pixels in the image. + + Returns + ------- + int + The number of pixels in the image. + """ + return int(np.prod(self.shape_fov)) + + def _init_fov_mm(self, fov_size_mm: Union[float, Sequence[float], NDArray], req_res_mm: float) -> NDArrayInt: + self.fov_size_mm = np.array(fov_size_mm, ndmin=1) + num_points = np.ceil(self.fov_size_mm / req_res_mm).astype(int) + self.feature_size_mm = self.fov_size_mm / num_points return num_points @@ -450,36 +472,38 @@ def generate_collection(self, buckets_fraction: float = 1, shift_type: str = "se In case of wrong shift type. """ if shift_type.lower() == "random": - num_chosen_buckets = np.ceil(self.num_buckets * buckets_fraction).astype(int) + num_chosen_buckets = np.ceil(self.num_pixels * buckets_fraction).astype(int) disp_v, disp_u = self.get_random_shifts(num_chosen_buckets) elif shift_type.lower() == "interval": interval = np.ceil(1 / buckets_fraction).astype(int) disp_v, disp_u = self.get_interval_shifts(interval) num_chosen_buckets = len(disp_v) elif shift_type.lower() == "sequential": - num_chosen_buckets = np.ceil(self.num_buckets * buckets_fraction).astype(int) + num_chosen_buckets = np.ceil(self.num_pixels * buckets_fraction).astype(int) disp_v, disp_u = self.get_sequential_shifts(num_chosen_buckets) else: - raise ValueError('Wrong shift_type: "%s". Available options: {random} | interval | sequential.' % shift_type) + raise ValueError(f'Wrong shift_type: "{shift_type}". Available options: random | interval | sequential.') - gen_masks_enc = self._generate_masks(disp_v, disp_u, mask_encoding=True) - gen_masks_dec = self._generate_masks(disp_v, disp_u, mask_encoding=False) if self._enc_dec_mismatch else None + gen_masks_enc = self._generate_mask_shifts(disp_v, disp_u, mask_encoding=True) + gen_masks_dec = self._generate_mask_shifts(disp_v, disp_u, mask_encoding=False) if self._enc_dec_mismatch else None - print("Using %d masks over %d" % (num_chosen_buckets, self.num_buckets)) + print( + f"Using {num_chosen_buckets} masks/buckets over {self.max_buckets} (max available), and {self.num_pixels} pixels" + ) return MaskCollection( gen_masks_enc, gen_masks_dec, mask_type=self.__mask_name__, mask_support=self.shape_mask, - mask_dims=len(self.shape_FoV), + mask_dims=len(self.shape_fov), ) def _apply_transmission(self, masks: NDArray) -> NDArray: return 1 - (1 - masks) * self.transmittance @abstractmethod - def generate_mask(self, mask_inds_vu: Union[Sequence, NDArray], mask_encoding: bool = True) -> NDArray: + def generate_shifted_mask(self, mask_inds_vu: Union[Sequence, NDArray], mask_encoding: bool = True) -> NDArray: """Produce the shifted masks. Parameters @@ -495,7 +519,7 @@ def generate_mask(self, mask_inds_vu: Union[Sequence, NDArray], mask_encoding: b The shifted mask. """ - def _generate_masks( + def _generate_mask_shifts( self, shifts_v: Union[Sequence, NDArray], shifts_u: Union[Sequence, NDArray], mask_encoding: bool = True ) -> NDArray: """Produce all the masks. @@ -516,8 +540,8 @@ def _generate_masks( NDArray The collection of all the shifted masks. """ - masks = [self.generate_mask([v, u], mask_encoding) for v, u in zip(shifts_v, shifts_u)] - return np.stack(masks, axis=0) # .reshape([*self.shift_shape, *self.FoV_shape]) + masks = [self.generate_shifted_mask([v, u], mask_encoding) for v, u in zip(shifts_v, shifts_u)] + return np.stack(masks, axis=0) # .reshape([*self.shift_shape, *self.fov_shape]) def get_interval_shifts( self, interval: Union[int, Sequence[int], NDArray], axes_order: Sequence[int] = (-2, -1) @@ -557,11 +581,9 @@ def get_random_shifts(self, num_shifts: int, axes_order: Sequence[int] = (-2, -1 NDArray The collection of shifts. """ - max_disps = np.prod(self.shape_shifts) - - if num_shifts > max_disps: - print("Warning, too many shifts. Truncating to: %d" % max_disps) - num_shifts = np.fmin(num_shifts, max_disps) + if num_shifts > self.max_buckets: + print(f"Warning, too many shifts. Truncating to: {self.max_buckets}") + num_shifts = np.fmin(num_shifts, self.max_buckets) disps = self.get_interval_shifts(interval=1, axes_order=axes_order) perms = np.random.permutation(np.prod(self.shape_shifts))[:num_shifts] @@ -595,20 +617,20 @@ class MaskGeneratorPoint(MaskGenerator): __mask_name__ = "pencil" - def __init__(self, FoV_size_mm: Union[float, Sequence[float], NDArray], req_res_mm: float = 1.0): + def __init__(self, fov_size_mm: Union[float, Sequence[float], NDArray], req_res_mm: float = 1.0): """Initialize the pencil beam mask collection. Parameters ---------- - FoV_size_mm : float + fov_size_mm : float Size of the Field-of-View in millimiters. req_res_mm : float Requested resolution in millimiters. """ - num_points = self._init_FoV_mm(FoV_size_mm, req_res_mm) + num_points = self._init_fov_mm(fov_size_mm, req_res_mm) super().__init__(num_points, [1, 1], num_points) - def generate_mask(self, mask_inds_vu: Union[Sequence, NDArray], mask_encoding: bool = True) -> NDArray: + def generate_shifted_mask(self, mask_inds_vu: Union[Sequence, NDArray], mask_encoding: bool = True) -> NDArray: """Produce the shifted masks. Parameters @@ -623,7 +645,7 @@ def generate_mask(self, mask_inds_vu: Union[Sequence, NDArray], mask_encoding: b NDArray The shifted mask. """ - mask = np.zeros(self.shape_FoV, dtype=self.dtype) + mask = np.zeros(self.shape_fov, dtype=self.dtype) mask[mask_inds_vu[0], mask_inds_vu[1]] = 1.0 return self._apply_transmission(mask) @@ -633,7 +655,7 @@ class MaskGeneratorBernoulli(MaskGenerator): __mask_name__ = "bernoulli" - def __init__(self, FoV_size_mm: Union[float, Sequence[float], NDArray], req_res_mm: float = 1.0): + def __init__(self, fov_size_mm: Union[float, Sequence[float], NDArray], req_res_mm: float = 1.0): """ Bernulli masks collection class. @@ -641,15 +663,15 @@ def __init__(self, FoV_size_mm: Union[float, Sequence[float], NDArray], req_res_ Parameters ---------- - FoV_size_mm : float - DESCRIPTION. + fov_size_mm : float + Size of the Field-of-View in mm. req_res_mm : float - DESCRIPTION. + The required pixel size in mm. """ - num_points = self._init_FoV_mm(FoV_size_mm, req_res_mm) - super().__init__(num_points, num_points, num_points) + num_points = self._init_fov_mm(fov_size_mm, req_res_mm) + super().__init__(shape_fov=num_points, shape_mask=num_points, shape_shifts=np.ceil(num_points * 1.2).astype(int)) - def generate_mask(self, mask_inds_vu: Union[Sequence, NDArray], mask_encoding: bool = True) -> NDArray: + def generate_shifted_mask(self, mask_inds_vu: Union[Sequence, NDArray], mask_encoding: bool = True) -> NDArray: """Produce the shifted masks. Parameters @@ -664,7 +686,7 @@ def generate_mask(self, mask_inds_vu: Union[Sequence, NDArray], mask_encoding: b NDArray The shifted mask. """ - mask = np.random.randint(0, 2, size=self.shape_FoV).astype(self.dtype) + mask = np.random.randint(0, 2, size=self.shape_fov).astype(self.dtype) return self._apply_transmission(mask) @@ -673,7 +695,7 @@ class MaskGeneratorHalfGaussian(MaskGenerator): __mask_name__ = "half-gaussian" - def __init__(self, FoV_size_mm: Union[float, Sequence[float], NDArray], req_res_mm: float = 1.0): + def __init__(self, fov_size_mm: Union[float, Sequence[float], NDArray], req_res_mm: float = 1.0): """ Half Gaussian masks collection class. @@ -681,15 +703,15 @@ def __init__(self, FoV_size_mm: Union[float, Sequence[float], NDArray], req_res_ Parameters ---------- - FoV_size_mm : float - DESCRIPTION. + fov_size_mm : float + Size of the Field-of-View in mm. req_res_mm : float - DESCRIPTION. + The required pixel size in mm. """ - num_points = self._init_FoV_mm(FoV_size_mm, req_res_mm) - super().__init__(num_points, num_points, num_points) + num_points = self._init_fov_mm(fov_size_mm, req_res_mm) + super().__init__(shape_fov=num_points, shape_mask=num_points, shape_shifts=np.ceil(num_points * 1.2).astype(int)) - def generate_mask(self, mask_inds_vu: Union[Sequence, NDArray], mask_encoding: bool = True) -> NDArray: + def generate_shifted_mask(self, mask_inds_vu: Union[Sequence, NDArray], mask_encoding: bool = True) -> NDArray: """Produce the shifted masks. Parameters @@ -704,7 +726,7 @@ def generate_mask(self, mask_inds_vu: Union[Sequence, NDArray], mask_encoding: b NDArray The shifted mask. """ - mask = np.abs(np.random.randn(*self.shape_FoV)).astype(self.dtype) + mask = np.abs(np.random.randn(*self.shape_fov)).astype(self.dtype) return self._apply_transmission(mask) @@ -713,21 +735,21 @@ class MaskGeneratorMURA(MaskGenerator): __mask_name__ = "mura" - def __init__(self, FoV_size_mm: float, req_res_mm: float = 1.0): + def __init__(self, fov_size_mm: float, req_res_mm: float = 1.0): """ MURA masks collection class. Parameters ---------- - FoV_size_mm : float - DESCRIPTION. + fov_size_mm : float + Size of the Field-of-View in mm. req_res_mm : float - DESCRIPTION. + The required pixel size in mm. """ - self.FoV_size_mm = np.array([FoV_size_mm, FoV_size_mm]) - base_points = int(np.ceil((FoV_size_mm / req_res_mm - 1) / 4)) + self.fov_size_mm = np.array([fov_size_mm, fov_size_mm]) + base_points = int(np.ceil((fov_size_mm / req_res_mm - 1) / 4)) num_points = 4 * base_points + 1 - self.feature_size_mm = self.FoV_size_mm / num_points + self.feature_size_mm = self.fov_size_mm / num_points sq = np.mod(np.arange(num_points) ** 2, num_points) @@ -744,7 +766,7 @@ def __init__(self, FoV_size_mm: float, req_res_mm: float = 1.0): super().__init__([num_points, num_points], [num_points, num_points], [num_points, num_points]) self._enc_dec_mismatch = True - def generate_mask(self, mask_inds_vu: Union[Sequence, NDArray], mask_encoding: bool = True) -> NDArray: + def generate_shifted_mask(self, mask_inds_vu: Union[Sequence, NDArray], mask_encoding: bool = True) -> NDArray: """Produce the shifted masks. Parameters @@ -770,7 +792,7 @@ def generate_mask(self, mask_inds_vu: Union[Sequence, NDArray], mask_encoding: b return self._apply_transmission(mask) @staticmethod - def compute_possible_mask_sizes(FoV_size: int) -> NDArray: + def compute_possible_mask_sizes(fov_size: int) -> NDArray: """Compute MURA masks sizes. MURA masks require specific edge sizes: prime numbers _x_ that also @@ -778,8 +800,8 @@ def compute_possible_mask_sizes(FoV_size: int) -> NDArray: Parameters ---------- - FoV_size : int - Edge size of the FoV in pixels. + fov_size : int + Edge size of the fov in pixels. Returns ------- @@ -791,7 +813,7 @@ def test_prime(x): div_val = x % np.arange(2, x // 2) return not np.any(div_val == 0) - max_possible_val = (FoV_size - 1) // 4 + max_possible_val = (fov_size - 1) // 4 test_values = np.arange(1, max_possible_val) * 4 + 1 primes = np.array([test_prime(x) for x in test_values]) return test_values[primes] @@ -815,7 +837,7 @@ def __init__(self, mask_collection: MaskCollection): self._axes_shifts = np.arange(len(self.mc.shape_shifts)) self.col_sum = np.abs(self.mc.masks_dec).sum(axis=tuple(self._axes_shifts)) - self._axes_fov = np.arange(-len(self.mc.shape_FoV), 0) + self._axes_fov = np.arange(-len(self.mc.shape_fov), 0) self.row_sum = np.sqrt(np.abs(self.mc.masks_enc * self.mc.masks_dec)).sum(axis=tuple(self._axes_fov)) self.vol_shape = self.mc.masks_enc.shape[-2:] @@ -835,7 +857,7 @@ def fp(self, image: NDArray) -> NDArray: NDArray The predicted bucket values. """ - masks_shape = [np.prod(self.mc.shape_shifts), np.prod(self.mc.shape_FoV)] + masks_shape = [np.prod(self.mc.shape_shifts), np.prod(self.mc.shape_fov)] image_shape = [*image.shape[: -self.mc.mask_dims], np.prod(image.shape[-self.mc.mask_dims :])] return np.squeeze(image.reshape(image_shape).dot(self.mc.masks_enc.reshape(masks_shape).T)) @@ -856,8 +878,8 @@ def bp(self, bucket_vals: NDArray) -> NDArray: NDArray Back-projected image. """ - masks_shape = [np.prod(self.mc.shape_shifts), np.prod(self.mc.shape_FoV)] - out_shape = [*bucket_vals.shape[:-1], *self.mc.shape_FoV] + masks_shape = [np.prod(self.mc.shape_shifts), np.prod(self.mc.shape_fov)] + out_shape = [*bucket_vals.shape[:-1], *self.mc.shape_fov] masks_in = self.mc.masks_dec.reshape(masks_shape) img_out: NDArray = bucket_vals.dot(masks_in) @@ -878,7 +900,7 @@ def adjust_sampling_scaling(self, image: NDArray) -> NDArray: NDArray Scaled image. """ - sampling_ratio = self.mc.num_buckets / np.prod(self.mc.shape_FoV) + sampling_ratio = self.mc.num_buckets / np.prod(self.mc.shape_fov) image_f: NDArray = np.fft.rfftn(image, axes=tuple(self._axes_fov), norm="ortho") rec_f_shape = np.array(image_f.shape)[list(self._axes_fov)] @@ -887,7 +909,7 @@ def adjust_sampling_scaling(self, image: NDArray) -> NDArray: sampling_filter[0, 0] = 1 image_f *= sampling_filter - return np.fft.irfftn(image_f, s=self.mc.shape_FoV, axes=tuple(self._axes_fov), norm="ortho") + return np.fft.irfftn(image_f, s=self.mc.shape_fov, axes=tuple(self._axes_fov), norm="ortho") def fbp(self, bucket_vals: NDArray, use_lstsq: bool = True, adjust_scaling: bool = True) -> NDArray: """Compute cross-correlation reconstruction of the bucket values. @@ -906,10 +928,10 @@ def fbp(self, bucket_vals: NDArray, use_lstsq: bool = True, adjust_scaling: bool if np.any(self.mc.masks_enc != self.mc.masks_dec) and not use_lstsq: img = self.bp(bucket_vals) else: - masks = self.mc.masks_enc.reshape([-1, np.prod(self.mc.shape_FoV)]) + masks = self.mc.masks_enc.reshape([-1, np.prod(self.mc.shape_fov)]) result = spalg.lstsq(masks, bucket_vals) img: NDArray = result[0] - img = img.reshape(self.mc.shape_FoV) + img = img.reshape(self.mc.shape_fov) if adjust_scaling: img = self.adjust_sampling_scaling(img)