Skip to content

Commit

Permalink
Cone-beam: refactored Ellipse class, and reorganized visualizer
Browse files Browse the repository at this point in the history
Signed-off-by: Nicola VIGANO <nicola.vigano@esrf.fr>
  • Loading branch information
Obi-Wan committed Mar 27, 2024
1 parent 5260782 commit 7c493ff
Show file tree
Hide file tree
Showing 3 changed files with 268 additions and 238 deletions.
158 changes: 14 additions & 144 deletions corrct/alignment/cone_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,8 @@ def __init__(
self._pre_fit()

def _pre_fit(self, use_least_squares: bool = False) -> None:
ell1_acq = fitting.Ellipse(self.points_ell1)
ell2_acq = fitting.Ellipse(self.points_ell2)
self.ell1_acq = fitting.Ellipse(self.points_ell1, least_squares=use_least_squares)
self.ell2_acq = fitting.Ellipse(self.points_ell2, least_squares=use_least_squares)

if self.points_axis is not None:
# Using measured projected center, whenever available
Expand All @@ -293,8 +293,8 @@ def _pre_fit(self, use_least_squares: bool = False) -> None:

self.prj_origin_vu = self.points_axis[:, 1]
else:
self.ell1_prj_center_vu = ell1_acq.fit_prj_center(least_squares=use_least_squares)
self.ell2_prj_center_vu = ell2_acq.fit_prj_center(least_squares=use_least_squares)
self.ell1_prj_center_vu = self.ell1_acq.center_vu
self.ell2_prj_center_vu = self.ell2_acq.center_vu

self.prj_origin_vu = None

Expand All @@ -315,30 +315,25 @@ def _pre_fit(self, use_least_squares: bool = False) -> None:
self.points_ell2_rot = self.points_ell2.copy()

# Re-instatiate ellipse class, after rotation
ell1_rot = fitting.Ellipse(self.points_ell1_rot)
ell2_rot = fitting.Ellipse(self.points_ell2_rot)

self.ell1_params = ell1_rot.fit_parameters(least_squares=use_least_squares)
self.ell2_params = ell2_rot.fit_parameters(least_squares=use_least_squares)
self.ell1_rot = fitting.Ellipse(self.points_ell1_rot, least_squares=use_least_squares)
self.ell2_rot = fitting.Ellipse(self.points_ell2_rot, least_squares=use_least_squares)

if self.plot_result:
fig, axs = plt.subplots()
axs.plot(self.points_ell1[1, :], self.points_ell1[0, :], "C0--", label="Ellipse 1 - Acquired")
axs.plot(self.points_ell2[1, :], self.points_ell2[0, :], "C1--", label="Ellipse 2 - Acquired")
axs.plot(self.points_ell1_rot[1, :], self.points_ell1_rot[0, :], "C0", label="Ellipse 1 - Rotated")
axs.plot(self.points_ell2_rot[1, :], self.points_ell2_rot[0, :], "C1", label="Ellipse 2 - Rotated")
ell1_acq_params = ell1_acq.fit_parameters(least_squares=use_least_squares)
ell2_acq_params = ell2_acq.fit_parameters(least_squares=use_least_squares)
axs.plot([ell1_acq_params[-1], ell2_acq_params[-1]], [ell1_acq_params[-2], ell2_acq_params[-2]], "C2--")
axs.plot([self.ell1_params[-1], self.ell2_params[-1]], [self.ell1_params[-2], self.ell2_params[-2]], "C2")
axs.plot([self.ell1_acq.u, self.ell2_acq.u], [self.ell1_acq.v, self.ell2_acq.v], "C2--")
axs.plot([self.ell1_rot.u, self.ell2_rot.u], [self.ell1_rot.v, self.ell2_rot.v], "C2")
if self.points_axis is not None:
axs.scatter(self.points_axis[1], self.points_axis[0], c="C2", marker="*", label="Centers - Acquired")
axs.legend()
axs.grid()
fig.tight_layout()
plt.show(block=False)

self.acq_geom.D = self._fit_distance_det2src(self.ell1_params, self.ell2_params)
self.acq_geom.D = self._fit_distance_det2src(self.ell1_rot, self.ell2_rot)

if self.verbose:
print(f"Fitted detector distance from source (pix): {self.acq_geom.D}")
Expand All @@ -359,8 +354,8 @@ def fit(self, r: float, e: float = 1) -> ConeBeamGeometry:
ValueError
In case of flipped ellipses.
"""
b1, a1, c1, v1, u1 = self.ell1_params
b2, a2, c2, v2, u2 = self.ell2_params
b1, a1, c1, v1, u1 = self.ell1_rot.parameters
b2, a2, c2, v2, u2 = self.ell2_rot.parameters

sign_z1 = -1
sign_z2 = sign_z1 * -e
Expand Down Expand Up @@ -430,11 +425,9 @@ def get_zeta(bk, ak, ck, D, sign_zk) -> float:
return self.acq_geom

@staticmethod
def _fit_distance_det2src(
ellipse_1: Union[ArrayLike, NDArray], ellipse_2: Union[ArrayLike, NDArray], e: float = 1
) -> float:
b1, a1, c1, v1, _ = np.array(ellipse_1)
b2, a2, c2, v2, _ = np.array(ellipse_2)
def _fit_distance_det2src(ellipse_1: fitting.Ellipse, ellipse_2: fitting.Ellipse, e: float = 1) -> float:
b1, a1, c1, v1, _ = ellipse_1.parameters
b2, a2, c2, v2, _ = ellipse_2.parameters

ecc2 = np.sqrt(b2 - c2**2 / a2)
ecc1 = np.sqrt(b1 - c1**2 / a1)
Expand Down Expand Up @@ -522,126 +515,3 @@ def tune_acquisition_geometry(
acq_geom_tuned = acq_geom_tuned.update(par_name, min_par)

return acq_geom_tuned


def cm2inch(dims: Union[ArrayLike, NDArray]) -> tuple[float]:
"""Convert cm into inch.
Parameters
----------
dims : Union[ArrayLike, NDArray]
The dimentions of the object in cm
Returns
-------
tuple[float]
The output dimensions in inch
"""
return tuple(np.array(dims) / 2.54)


class MarkerVisualizer:
"""Plotting class to assess the calibration quality."""

def __init__(
self,
fitted_positions_vu: Union[ArrayLike, NDArray],
imgs: NDArray,
disk: NDArray,
ell_params: Union[ArrayLike, NDArray, None] = None,
) -> None:
self.positions_vu = np.array(fitted_positions_vu)
self.imgs = imgs
self.disk = disk
self.global_lims = False

if ell_params is not None:
ell_params = np.array(ell_params)
self.ell_params = ell_params

self.curr_pos = 0

if self.ell_params is not None:
us = np.sort(self.positions_vu[1, :])
self.v_1, self.v_2 = fitting.Ellipse.predict_v(self.ell_params, us)

self.fig, self.axs = plt.subplots(1, 3, figsize=cm2inch([36, 12])) # , sharex=True, sharey=True
self.axs[2].imshow(self.disk)
self.axs[0].set_xlim(0, self.imgs.shape[-1])
self.axs[0].set_ylim(self.imgs.shape[-3], 0)
self.fig.tight_layout()
self.update()

self.fig.canvas.mpl_connect("key_press_event", self._key_event)
self.fig.canvas.mpl_connect("scroll_event", self._scroll_event)

def update(self) -> None:
self.curr_pos = self.curr_pos % self.imgs.shape[-2]

for img in self.axs[0].get_images():
img.remove()
x_lims = self.axs[0].get_xlim()
y_lims = self.axs[0].get_ylim()
self.axs[0].cla()
self.axs[0].set_xlim(x_lims[0], x_lims[1])
self.axs[0].set_ylim(y_lims[0], y_lims[1])

for img in self.axs[1].get_images():
img.remove()
self.axs[1].cla()

self.axs[0].plot(self.positions_vu[1, :], self.positions_vu[0, :], "bo-", markersize=4)
self.axs[0].scatter(self.positions_vu[1, self.curr_pos], self.positions_vu[0, self.curr_pos], c="r")

if self.ell_params is not None:
us = np.sort(self.positions_vu[1, :])
self.axs[0].plot(us, self.v_1, "g")
self.axs[0].plot(us, self.v_2, "g")
self.axs[0].grid()

if self.global_lims:
vmin = self.imgs.min()
vmax = self.imgs.max()
else:
vmin = self.imgs[:, self.curr_pos, :].min()
vmax = self.imgs[:, self.curr_pos, :].max()

img = self.axs[1].imshow(self.imgs[:, self.curr_pos, :], vmin=vmin, vmax=vmax)
self.axs[1].scatter(self.positions_vu[1, self.curr_pos], self.positions_vu[0, self.curr_pos], c="r")
self.axs[1].set_title(f"Range: [{vmin}, {vmax}]")
# plt.colorbar(im, ax=self.axs[1])
self.fig.canvas.draw()

def _key_event(self, evnt) -> None:
if evnt.key == "right":
self.curr_pos += 1
elif evnt.key == "left":
self.curr_pos -= 1
elif evnt.key == "up":
self.curr_pos += 1
elif evnt.key == "down":
self.curr_pos -= 1
elif evnt.key == "pageup":
self.curr_pos += 10
elif evnt.key == "pagedown":
self.curr_pos -= 10
elif evnt.key == "escape":
plt.close(self.fig)
elif evnt.key == "ctrl+l":
self.global_lims = not self.global_lims
else:
print(evnt.key)
return

self.update()

def _scroll_event(self, evnt) -> None:
if evnt.button == "up":
self.curr_pos += 1
elif evnt.button == "down":
self.curr_pos -= 1
else:
print(evnt.key)
return

self.update()
Loading

0 comments on commit 7c493ff

Please sign in to comment.