Skip to content

Commit

Permalink
Processing/post: added function to fit the scale and bias of a recosn…
Browse files Browse the repository at this point in the history
…truction

Signed-off-by: Nicola VIGANO <nicola.vigano@esrf.fr>
  • Loading branch information
Obi-Wan committed Mar 13, 2024
1 parent 4ca71db commit b49d751
Showing 1 changed file with 49 additions and 11 deletions.
60 changes: 49 additions & 11 deletions corrct/processing/post.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
Post-processing routines.
Expand All @@ -9,20 +8,21 @@
"""

from collections.abc import Sequence
import numpy as np
import scipy as sp

from numpy.typing import ArrayLike, NDArray

from .misc import azimuthal_integration, lines_intersection, circular_mask
from typing import Optional

import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import numpy as np
import scipy as sp
from matplotlib.axes._axes import Axes
from matplotlib.figure import Figure
from numpy.typing import ArrayLike, NDArray
from scipy.optimize import minimize

from tqdm.auto import tqdm

from corrct.operators import BaseTransform, TransformIdentity
from corrct.processing.misc import azimuthal_integration, circular_mask, lines_intersection


eps = np.finfo(np.float32).eps

Expand Down Expand Up @@ -51,13 +51,13 @@ def com(vol: NDArray, axes: Optional[ArrayLike] = None) -> NDArray:
coords = [np.linspace(-(s - 1) / 2, (s - 1) / 2, s) for s in np.array(vol.shape)[list(axes)]]

num_dims = len(vol.shape)
com = np.empty((len(axes),))
center_of_mass = np.empty((len(axes),))
for ii, a in enumerate(axes):
sum_axes = np.array(np.delete(np.arange(num_dims), a), ndmin=1, dtype=int)
line = np.abs(vol).sum(axis=tuple(sum_axes))
com[ii] = line.dot(coords[ii]) / line.sum()
center_of_mass[ii] = line.dot(coords[ii]) / line.sum()

return com
return center_of_mass


def power_spectrum(
Expand Down Expand Up @@ -331,3 +331,41 @@ def plot_frcs(
plt.show(block=False)

return fig, axs


def fit_scale_bias(img_data: NDArray, prj_data: NDArray, prj: Optional[BaseTransform] = None) -> tuple[float, float]:
"""Fit the scale and bias of an image, against its projection in a different space.
Parameters
----------
img_data : NDArray
The image data
prj_data : NDArray
The projected data
prj : BaseTransform | None, optional
The projection operator. The default is None, which uses the identity (TransformIdentity)
Returns
-------
tuple[float, float]
The scale and bias
"""
if prj is None:
prj = TransformIdentity(img_data.shape)

prj_x = prj(img_data)
prj_1 = prj(np.ones_like(img_data))
m_y_dot_prj_x = -float(np.sum(prj_data * prj_x))
m_y_dot_prj_1 = -float(np.sum(prj_data * prj_1))
prj_x_2 = float(np.sum(prj_x**2))
prj_1_2 = float(np.sum(prj_1**2))
prj_1_dot_prj_x = float(np.sum(prj_1 * prj_x))

def obj_func(ab: NDArray) -> tuple[float, NDArray]:
residual = prj(img_data * ab[0] + ab[1]) - prj_data
grad_a = m_y_dot_prj_x + prj_x_2 * ab[0] + prj_1_dot_prj_x * ab[1]
grad_b = m_y_dot_prj_1 + prj_1_2 * ab[1] + prj_1_dot_prj_x * ab[0]
return float(np.linalg.norm(residual, ord=2) ** 2) / 2, np.array((grad_a, grad_b))

opt_res = minimize(obj_func, [1.0, 0.0], jac=True)
return float(opt_res.x[0]), float(opt_res.x[1])

0 comments on commit b49d751

Please sign in to comment.