diff --git a/atomai/losses_metrics/vi_losses.py b/atomai/losses_metrics/vi_losses.py index f6510b17..b806ae84 100644 --- a/atomai/losses_metrics/vi_losses.py +++ b/atomai/losses_metrics/vi_losses.py @@ -89,6 +89,7 @@ def vae_loss(recon_loss: str, x: torch.Tensor, x_reconstr: torch.Tensor, *args: torch.Tensor, + **kwargs: List[float] ) -> torch.Tensor: """ Calculates ELBO @@ -98,9 +99,13 @@ def vae_loss(recon_loss: str, else: raise ValueError( "Pass mean and SD values of encoded distribution as args") + capacity = kwargs.get("capacity") + num_iter = kwargs.get("num_iter", 0) likelihood = -reconstruction_loss(recon_loss, in_dim, x, x_reconstr).mean() - kl_z = kld_normal(q_param).mean() - return likelihood - kl_z + kl_div = kld_normal(q_param).mean() + if capacity is not None: + kl_div = infocapacity(kl_div, capacity, num_iter=num_iter) + return likelihood - kl_div def rvae_loss(recon_loss: str, @@ -108,7 +113,8 @@ def rvae_loss(recon_loss: str, x: torch.Tensor, x_reconstr: torch.Tensor, *args: torch.Tensor, - **kwargs: float) -> torch.Tensor: + **kwargs: Union[List[float], float] + ) -> torch.Tensor: """ Calculates ELBO """ @@ -118,13 +124,16 @@ def rvae_loss(recon_loss: str, raise ValueError( "Pass mean and SD values of encoded distribution as args") phi_prior = kwargs.get("phi_prior", 0.1) - b1, b2 = kwargs.get("b1", 1), kwargs.get("b2", 1) + capacity = kwargs.get("capacity") + num_iter = kwargs.get("num_iter", 0) phi_logsd = z_logsd[:, 0] z_mean, z_logsd = z_mean[:, 1:], z_logsd[:, 1:] likelihood = -reconstruction_loss(recon_loss, in_dim, x, x_reconstr).mean() kl_rot = kld_rot(phi_prior, phi_logsd).mean() kl_z = kld_normal([z_mean, z_logsd]).mean() - kl_div = (b1*kl_z + b2 * kl_rot) + kl_div = (kl_z + kl_rot) + if capacity is not None: + kl_div = infocapacity(kl_div, capacity, num_iter=num_iter) return likelihood - kl_div @@ -145,8 +154,8 @@ def joint_vae_loss(recon_loss: str, "Pass continuous (mean, SD) and discrete (alphas) values" + "of encoded distributions as args") - cont_capacity = kwargs.get("cont_capacity", [0.0, 5.0, 25000, 30]) - disc_capacity = kwargs.get("disc_capacity", [0.0, 5.0, 25000, 30]) + cont_capacity = kwargs.get("cont_capacity", [5.0, 25000, 30]) + disc_capacity = kwargs.get("disc_capacity", [5.0, 25000, 30]) num_iter = kwargs.get("num_iter", 0) disc_dims = [a.size(1) for a in alphas] @@ -160,7 +169,7 @@ def joint_vae_loss(recon_loss: str, kl_disc_loss = torch.sum(torch.cat(kl_disc)) # Apply information capacity terms to contninuous and discrete channels - cargs = [kl_cont_loss, kl_disc_loss, cont_capacity, + cargs = [kl_cont_loss, cont_capacity, kl_disc_loss, disc_capacity, disc_dims, num_iter] cont_capacity_loss, disc_capacity_loss = infocapacity(*cargs) @@ -172,7 +181,8 @@ def joint_rvae_loss(recon_loss: str, x: torch.Tensor, x_reconstr: torch.Tensor, *args: torch.Tensor, - **kwargs: float) -> torch.Tensor: + **kwargs: Union[List, float, int] + ) -> torch.Tensor: """ Calculates joint ELBO for continuous and discrete variables """ @@ -184,9 +194,8 @@ def joint_rvae_loss(recon_loss: str, "of encoded distributions as args") phi_prior = kwargs.get("phi_prior", 0.1) - klrot_cap = kwargs.get("klrot_cap", True) - cont_capacity = kwargs.get("cont_capacity", [0.0, 5.0, 25000, 30]) - disc_capacity = kwargs.get("disc_capacity", [0.0, 5.0, 25000, 30]) + cont_capacity = kwargs.get("cont_capacity", [5.0, 25000, 30]) + disc_capacity = kwargs.get("disc_capacity", [5.0, 25000, 30]) num_iter = kwargs.get("num_iter", 0) # Calculate reconstruction loss term @@ -197,10 +206,7 @@ def joint_rvae_loss(recon_loss: str, z_mean, z_logsd = z_mean[:, 1:], z_logsd[:, 1:] # image content kl_rot = kld_rot(phi_prior, phi_logsd).mean() kl_z = kld_normal([z_mean, z_logsd]).mean() - if klrot_cap: - kl_cont_loss = kl_z + kl_rot - else: # no capacity limit on KL term associated with rotations - kl_cont_loss = kl_z + kl_cont_loss = kl_z + kl_rot # Calculate KL term for discrete latent variables disc_dims = [a.size(1) for a in alphas] @@ -208,47 +214,38 @@ def joint_rvae_loss(recon_loss: str, kl_disc_loss = torch.sum(torch.cat(kl_disc)) # Apply information capacity terms to contninuous and discrete channels - cargs = [kl_cont_loss, kl_disc_loss, cont_capacity, + cargs = [kl_cont_loss, cont_capacity, kl_disc_loss, disc_capacity, disc_dims, num_iter] cont_capacity_loss, disc_capacity_loss = infocapacity(*cargs) - if not klrot_cap: - cont_capacity_loss = cont_capacity_loss + kl_rot return likelihood - cont_capacity_loss - disc_capacity_loss def infocapacity(kl_cont_loss: torch.Tensor, - kl_disc_loss: torch.Tensor, cont_capacity: List[float], - disc_capacity: List[float], - disc_dims: List[int], - num_iter: int) -> torch.Tensor: - """ - Controls information capacity of the continuous and discrete loss - (based on https://arxiv.org/pdf/1804.00104.pdf & - https://github.com/Schlumberger/joint-vae/blob/master/jointvae/training.py) - """ - # Linearly increase capacity of continuous channels - cont_min, cont_max, cont_num_iters, cont_gamma = cont_capacity - # Increase continuous capacity without exceeding cont_max - cont_cap_current = (cont_max - cont_min) * num_iter - cont_cap_current = cont_cap_current / float(cont_num_iters) + cont_min - cont_cap_current = min(cont_cap_current, cont_max) - # Calculate continuous capacity loss - cont_capacity_loss = cont_gamma*torch.abs(cont_cap_current - kl_cont_loss) - - # Linearly increase capacity of discrete channels - disc_min, disc_max, disc_num_iters, disc_gamma = disc_capacity - # Increase discrete capacity without exceeding disc_max or theoretical - # maximum (i.e. sum of log of dimension of each discrete variable) - disc_cap_current = (disc_max - disc_min) * num_iter - disc_cap_current = disc_cap_current / float(disc_num_iters) + disc_min - disc_cap_current = min(disc_cap_current, disc_max) - # Require float conversion here to not end up with numpy float + kl_disc_loss: Optional[torch.Tensor] = None, + disc_capacity: Optional[List[float]] = None, + disc_dims: Optional[List[int]] = None, + num_iter: int = 0 + ) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + """ + Controls information capacity of KL term(s) + (see https://arxiv.org/pdf/1804.03599.pdf & https://arxiv.org/pdf/1804.00104.pdf) + """ + # Increase capacity of continuous latent channel + cont_max, cont_num_iters, cont_gamma = cont_capacity + cont_cap = cont_max * (num_iter / float(cont_num_iters)) + cont_cap = min(cont_cap, cont_max) + # Calculate continuous KL term + cont_capacity_loss = cont_gamma * torch.abs(kl_cont_loss - cont_cap) + if kl_disc_loss is None: + return cont_capacity_loss + # Increase capacity of discrete latent channel + disc_max, disc_num_iters, disc_gamma = disc_capacity disc_theory_max = sum([float(np.log(d)) for d in disc_dims]) - disc_cap_current = min(disc_cap_current, disc_theory_max) - # Calculate discrete capacity loss - disc_capacity_loss = disc_gamma*torch.abs(disc_cap_current - kl_disc_loss) + disc_cap = disc_max * (num_iter / float(disc_num_iters)) + disc_cap = min(disc_cap, disc_max, disc_theory_max) + # Calculate discrete KL term + disc_capacity_loss = disc_gamma * torch.abs(disc_cap - kl_disc_loss) return cont_capacity_loss, disc_capacity_loss - diff --git a/atomai/models/dgm/jrvae.py b/atomai/models/dgm/jrvae.py index eb8c9108..e6d9d064 100644 --- a/atomai/models/dgm/jrvae.py +++ b/atomai/models/dgm/jrvae.py @@ -36,7 +36,8 @@ class jrVAE(BaseVAE): List specifying dimensionalities of discrete (Gumbel-Softmax) latent variables associated with image content nb_classes: - Number of classes for class-conditional VAE + Number of classes for class-conditional VAE. + (leave it at 0 to learn discrete latent reprenetations) translation: account for xy shifts of image content (Default: True) seed: @@ -87,7 +88,6 @@ def __init__(self, self.translation = translation self.dx_prior = None self.phi_prior = None - self.anneal_dict = None self.kdict_ = dc(kwargs) self.kdict_["num_iter"] = 0 @@ -166,13 +166,13 @@ def fit(self, 3D or 4D stack of training images with dimensions (n_images, height, width) for grayscale data or or (n_images, height, width, channels) for multi-channel data - X_test: - 3D or 4D stack of test images with the same dimensions - as for the X_train (Default: None) y_train: Vector with labels of dimension (n_images,), where n_images is a number of training images - y_train: + X_test: + 3D or 4D stack of test images with the same dimensions + as for the X_train (Default: None) + y_test: Vector with labels of dimension (n_images,), where n_images is a number of test images loss: @@ -184,28 +184,23 @@ def fit(self, **temperature (float): Relaxation parameter for Gumbel-Softmax distribution **cont_capacity (list): - List containing (min_capacity, max_capacity, num_iters, gamma_z) - parameters to control the capacity of the continuous latent - channels. Default values: [0.0, 5.0, 25000, 30]. - Based on https://arxiv.org/abs/1804.00104 + List containing (max_capacity, num_iters, gamma) parameters + to control the capacity of the continuous latent channel. + Default values: [5.0, 25000, 30]. + Based on https://arxiv.org/pdf/1804.03599.pdf & https://arxiv.org/abs/1804.00104 **disc_capacity (list): - List containing (min_capacity, max_capacity, num_iters, gamma_c) - parameters to control the capacity of the discrete latent channels. - Default values: [0.0, 5.0, 25000, 30]. - Based on https://arxiv.org/abs/1804.00104 - **klrot_cap (bool): - Do not control capacity of KL term associated - with rotations of coordinate grid + List containing (max_capacity, num_iters, gamma) parameters + to control the capacity of the discrete latent channel(s). + Default values: [5.0, 25000, 30]. + Based on https://arxiv.org/pdf/1804.03599.pdf & https://arxiv.org/abs/1804.00104 **filename (str): - file path for saving model aftereach training cycle ("epoch") + file path for saving model after each training cycle ("epoch") """ self._check_inputs(X_train, y_train, X_test, y_test) self.dx_prior = kwargs.get("translation_prior", 0.1) self.kdict_["phi_prior"] = kwargs.get("rotation_prior", 0.1) - self.anneal_dict = kwargs.get("anneal_dict") for k, v in kwargs.items(): - if k in ["cont_capacity", "disc_capacity", - "temperature", "klrot_cap"]: + if k in ["cont_capacity", "disc_capacity", "temperature"]: self.kdict_[k] = v self.compile_trainer( (X_train, y_train), (X_test, y_test), **kwargs) diff --git a/atomai/models/dgm/jvae.py b/atomai/models/dgm/jvae.py index ce640c95..483d4224 100644 --- a/atomai/models/dgm/jvae.py +++ b/atomai/models/dgm/jvae.py @@ -36,6 +36,7 @@ class jVAE(BaseVAE): latent variables associated with image content nb_classes: Number of classes for class-conditional VAE + (leave it at 0 to learn discrete latent reprenetations) seed: seed for torch and numpy (pseudo-)random numbers generators **conv_encoder (bool): @@ -148,27 +149,27 @@ def fit(self, (n_images, height, width) for grayscale data or or (n_images, height, width, channels) for multi-channel data. For spectra, 2D stack of spectra with dimensions (length,) - X_test: - 3D or 4D stack of test images or 2D stack of spectra with - the same dimensions as for the X_train (Default: None) y_train: Vector with labels of dimension (n_images,), where n_images is a number of training images/spectra - y_train: + X_test: + 3D or 4D stack of test images or 2D stack of spectra with + the same dimensions as for the X_train (Default: None) + y_test: Vector with labels of dimension (n_images,), where n_images is a number of test images/spectra loss: reconstruction loss function, "ce" or "mse" (Default: "mse") **cont_capacity (list): - List containing (min_capacity, max_capacity, num_iters, gamma_z) - parameters to control the capacity of the continuous latent - channels. Default values: [0.0, 5.0, 25000, 30]. - Based on https://arxiv.org/abs/1804.00104 + List containing (max_capacity, num_iters, gamma) parameters + to control the capacity of the continuous latent channel. + Default values: [5.0, 25000, 30]. + Based on https://arxiv.org/pdf/1804.03599.pdf & https://arxiv.org/abs/1804.00104 **disc_capacity (list): - List containing (min_capacity, max_capacity, num_iters, gamma_c) - parameters to control the capacity of the discrete latent channels. - Default values: [0.0, 5.0, 25000, 30]. - Based on https://arxiv.org/abs/1804.00104 + List containing (max_capacity, num_iters, gamma) parameters + to control the capacity of the discrete latent channel(s). + Default values: [5.0, 25000, 30]. + Based on https://arxiv.org/pdf/1804.03599.pdf & https://arxiv.org/abs/1804.00104 **filename (str): file path for saving model aftereach training cycle ("epoch") """ diff --git a/atomai/models/dgm/rvae.py b/atomai/models/dgm/rvae.py index 6f3f7a40..87ce5bfa 100644 --- a/atomai/models/dgm/rvae.py +++ b/atomai/models/dgm/rvae.py @@ -8,14 +8,13 @@ Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com) """ - -from typing import Optional, Union +from copy import deepcopy as dc +from typing import Optional, Union, List import numpy as np import torch from ...losses_metrics import rvae_loss - from ...utils import set_train_rng, to_onehot, transform_coordinates from .vae import BaseVAE @@ -57,7 +56,7 @@ class rVAE(BaseVAE): Example: - >>> input_dim = (28, 28) # intput dimensions + >>> input_dim = (28, 28) # input dimensions >>> # Intitialize model >>> rvae = aoi.models.rVAE(input_dim) >>> # Train @@ -94,13 +93,14 @@ def __init__(self, self.translation = translation self.dx_prior = None self.phi_prior = None - self.anneal_dict = None + self.kdict_ = dc(kwargs) + self.kdict_["num_iter"] = 0 def elbo_fn(self, x: torch.Tensor, x_reconstr: torch.Tensor, *args: torch.Tensor, - **kwargs: float + **kwargs: Union[List, float, int] ) -> torch.Tensor: """ Computes ELBO @@ -121,6 +121,7 @@ def forward_compute_elbo(self, z_mean, z_logsd = self.encoder_net(x) else: z_mean, z_logsd = self.encoder_net(x) + self.kdict_["num_iter"] += 1 z_sd = torch.exp(z_logsd) z = self.reparameterize(z_mean, z_sd) phi = z[:, 0] # angle @@ -142,16 +143,8 @@ def forward_compute_elbo(self, x_reconstr = self.decoder_net(x_coord_, z) else: x_reconstr = self.decoder_net(x_coord_, z) - # KL annealing terms - b1 = b2 = 1 - if isinstance(self.anneal_dict, dict): - e_ = self.current_epoch - b1 = self.anneal_dict["kl_im"] - b2 = self.anneal_dict["kl_rot"] - b1 = b1[-1] if len(b1) < e_ + 1 else b1[e_] - b2 = b2[-1] if len(b2) < e_ + 1 else b2[e_] - return self.elbo_fn(x, x_reconstr, z_mean, z_logsd, - phi_prior=self.phi_prior, b1=b1, b2=b2) + + return self.elbo_fn(x, x_reconstr, z_mean, z_logsd, **self.kdict_) def fit(self, X_train: Union[np.ndarray, torch.Tensor], @@ -168,13 +161,13 @@ def fit(self, 3D or 4D stack of training images with dimensions (n_images, height, width) for grayscale data or or (n_images, height, width, channels) for multi-channel data - X_test: - 3D or 4D stack of test images with the same dimensions - as for the X_train (Default: None) y_train: Vector with labels of dimension (n_images,), where n_images is a number of training images - y_train: + X_test: + 3D or 4D stack of test images with the same dimensions + as for the X_train (Default: None) + y_test: Vector with labels of dimension (n_images,), where n_images is a number of test images loss: @@ -183,6 +176,10 @@ def fit(self, translation prior **rotation_prior (float): rotational prior + **capacity (list): + List containing (max_capacity, num_iters, gamma) parameters + to control the capacity of the latent channel. + Based on https://arxiv.org/pdf/1804.03599.pdf **filename (str): file path for saving model aftereach training cycle ("epoch") **recording (bool): @@ -190,8 +187,10 @@ def fit(self, """ self._check_inputs(X_train, y_train, X_test, y_test) self.dx_prior = kwargs.get("translation_prior", 0.1) - self.phi_prior = kwargs.get("rotation_prior", 0.1) - self.anneal_dict = kwargs.get("anneal_dict") + self.kdict_["phi_prior"] = kwargs.get("rotation_prior", 0.1) + for k, v in kwargs.items(): + if k in ["capacity"]: + self.kdict_[k] = v self.compile_trainer( (X_train, y_train), (X_test, y_test), **kwargs) self.loss = loss # this part needs to be handled better @@ -208,8 +207,13 @@ def fit(self, elbo_epoch_test = self.evaluate_model() self.loss_history["test_loss"].append(elbo_epoch_test) self.print_statistics(e) + self.update_metadict() if self.recording and self.z_dim in [3, 5]: self.manifold2d(savefig=True, filename=str(e)) self.save_model(self.filename) if self.recording and self.z_dim in [3, 5]: self.visualize_manifold_learning("./vae_learning") + + def update_metadict(self): + self.metadict["num_epochs"] = self.current_epoch + self.metadict["num_iter"] = self.kdict_["num_iter"] diff --git a/atomai/models/dgm/vae.py b/atomai/models/dgm/vae.py index 225bd236..4ebd679e 100644 --- a/atomai/models/dgm/vae.py +++ b/atomai/models/dgm/vae.py @@ -8,6 +8,7 @@ """ import os +from copy import deepcopy as dc from typing import Dict, List, Optional, Tuple, Union import matplotlib.pyplot as plt @@ -17,7 +18,6 @@ from torchvision.utils import make_grid from ...losses_metrics import vae_loss - from ...nets import init_VAE_nets from ...trainers import viBaseTrainer from ...utils import (crop_borders, extract_subimages, get_coord_grid, @@ -408,17 +408,17 @@ def manifold2d(self, **kwargs: Union[int, List, str, bool]) -> None: # use torc elif len(self.in_dim) == 3: figure = np.zeros((self.in_dim[0] * d, self.in_dim[1] * d, self.in_dim[-1])) if l1 and l2: - grid_x = np.linspace(l1[0], l1[1], d) + grid_x = np.linspace(l1[1], l1[0], d) grid_y = np.linspace(l2[0], l2[1], d) else: - grid_x = norm.ppf(np.linspace(0.05, 0.95, d)) + grid_x = norm.ppf(np.linspace(0.95, 0.05, d)) grid_y = norm.ppf(np.linspace(0.05, 0.95, d)) if self.discrete_dim: z_disc = np.zeros((sum(self.discrete_dim)))[None] z_disc[:, kwargs.get("disc_idx", 0)] = 1 - for i, yi in enumerate(grid_x): - for j, xi in enumerate(grid_y): + for i, xi in enumerate(grid_x): + for j, yi in enumerate(grid_y): z_sample = np.array([[xi, yi]]) if self.discrete_dim: z_sample = np.concatenate((z_sample, z_disc), -1) @@ -432,7 +432,10 @@ def manifold2d(self, **kwargs: Union[int, List, str, bool]) -> None: # use torc figure = (figure - figure.min()) / figure.ptp() fig, ax = plt.subplots(figsize=(10, 10)) - ax.imshow(figure, cmap=cmap, origin=kwargs.get("origin", "lower")) + ax.imshow(figure, cmap=cmap, origin=kwargs.get("origin", "lower"), + extent=[grid_x.min(), grid_x.max(), grid_y.min(), grid_y.max()]) + ax.set_xlabel("$z_1$") + ax.set_ylabel("$z_2$") draw_grid = kwargs.get("draw_grid") if draw_grid: major_ticks_x = np.arange(0, d * self.in_dim[0], self.in_dim[0]) @@ -440,6 +443,9 @@ def manifold2d(self, **kwargs: Union[int, List, str, bool]) -> None: # use torc ax.set_xticks(major_ticks_x) ax.set_yticks(major_ticks_y) ax.grid(which='major', alpha=0.6) + for item in ([ax.xaxis.label, ax.yaxis.label] + + ax.get_xticklabels() + ax.get_yticklabels()): + item.set_fontsize(18) if not kwargs.get("savefig"): plt.show() else: @@ -447,8 +453,6 @@ def manifold2d(self, **kwargs: Union[int, List, str, bool]) -> None: # use torc fname = kwargs.get("filename", "manifold_2d") if not os.path.exists(savedir): os.makedirs(savedir) - ax.set_xticklabels([]) - ax.set_yticklabels([]) fig.savefig(os.path.join(savedir, '{}.png'.format(fname))) plt.close(fig) @@ -503,6 +507,11 @@ def manifold_traversal(self, cont_idx: int, plt.figure(figsize=(12, 12)) plt.imshow(grid, cmap='gnuplot', origin=kwargs.get("origin", "lower")) + plt.xlabel("$z_{cont}$", fontsize=18) + plt.ylabel("$z_{disc}$", fontsize=18) + plt.xticks([]) + plt.yticks([]) + plt.show() return grid @classmethod @@ -632,13 +641,16 @@ def __init__(self, ) -> None: super(VAE, self).__init__(in_dim, latent_dim, nb_classes, 0, **kwargs) set_train_rng(seed) + self.kdict_ = dc(kwargs) + self.kdict_["num_iter"] = 0 def elbo_fn(self, x: torch.Tensor, x_reconstr: torch.Tensor, - *args: torch.Tensor) -> torch.Tensor: + *args: torch.Tensor, + **kwargs) -> torch.Tensor: """ Calculates ELBO """ - return vae_loss(self.loss, self.in_dim, x, x_reconstr, *args) + return vae_loss(self.loss, self.in_dim, x, x_reconstr, *args, **kwargs) def forward_compute_elbo(self, x: torch.Tensor, @@ -654,6 +666,7 @@ def forward_compute_elbo(self, z_mean, z_logsd = self.encoder_net(x) else: z_mean, z_logsd = self.encoder_net(x) + self.kdict_["num_iter"] += 1 z_sd = torch.exp(z_logsd) z = self.reparameterize(z_mean, z_sd) if y is not None: @@ -665,7 +678,7 @@ def forward_compute_elbo(self, else: x_reconstr = self.decoder_net(z) - return self.elbo_fn(x, x_reconstr, z_mean, z_logsd) + return self.elbo_fn(x, x_reconstr, z_mean, z_logsd, **self.kdict_) def fit(self, X_train: Union[np.ndarray, torch.Tensor], @@ -683,21 +696,28 @@ def fit(self, (n_images, height, width) for grayscale data or or (n_images, height, width, channels) for multi-channel data. For spectra, 2D stack of spectra with dimensions (length,) - X_test: - 3D or 4D stack of test images or 2D stack of spectra with - the same dimensions as for the X_train (Default: None) y_train: Vector with labels of dimension (n_images,), where n_images is a number of training images/spectra - y_train: + X_test: + 3D or 4D stack of test images or 2D stack of spectra with + the same dimensions as for the X_train (Default: None) + y_test: Vector with labels of dimension (n_images,), where n_images is a number of test images/spectra loss: reconstruction loss function, "ce" or "mse" (Default: "mse") + **capacity (list): + List containing (max_capacity, num_iters, gamma) parameters + to control the capacity of the latent channel. + Based on https://arxiv.org/pdf/1804.03599.pdf **filename (str): file path for saving model aftereach training cycle ("epoch") """ self._check_inputs(X_train, y_train, X_test, y_test) + for k, v in kwargs.items(): + if k in ["capacity"]: + self.kdict_[k] = v self.compile_trainer( (X_train, y_train), (X_test, y_test), **kwargs) self.loss = loss # this part needs to be handled better @@ -712,5 +732,10 @@ def fit(self, elbo_epoch_test = self.evaluate_model() self.loss_history["test_loss"].append(elbo_epoch_test) self.print_statistics(e) + self.update_metadict() self.save_model(self.filename) return + + def update_metadict(self): + self.metadict["num_epochs"] = self.current_epoch + self.metadict["num_iter"] = self.kdict_["num_iter"] diff --git a/atomai/models/loaders.py b/atomai/models/loaders.py index 7af23f99..6f875418 100644 --- a/atomai/models/loaders.py +++ b/atomai/models/loaders.py @@ -67,10 +67,11 @@ def load_seg_model(meta_dict: Dict[str, torch.Tensor]) -> Type[Segmentor]: model_name = meta_dict.pop("model") nb_classes = meta_dict.pop("nb_classes") weights = meta_dict.pop("weights") - optimizer = meta_dict.pop("optimizer") model = Segmentor(model_name, nb_classes, **meta_dict) model.net.load_state_dict(weights) - model.optimizer = optimizer + if "optimizer" in meta_dict.keys(): + optimizer = meta_dict.pop("optimizer") + model.optimizer = optimizer model.net.eval() return model diff --git a/atomai/nets/ed.py b/atomai/nets/ed.py index 3d814ba9..b0049482 100644 --- a/atomai/nets/ed.py +++ b/atomai/nets/ed.py @@ -489,7 +489,6 @@ def __init__(self, latent_dim: int, num_layers: int = 2, hidden_dim: int = 32, - num_classes: int = 0, # num_classes is redundant (just pass latent_dim = latent_dim + nb_classes in init_vae_nets) **kwargs: float) -> None: """ Initializes network parameters @@ -502,7 +501,7 @@ def __init__(self, dim = 2 if len(out_dim) > 1 else 1 c = out_dim[-1] if len(out_dim) > 2 else 1 self.fc_linear = nn.Linear( - latent_dim + num_classes, hidden_dim * np.product(out_dim[:2]), + latent_dim, hidden_dim * np.product(out_dim[:2]), bias=False) self.reshape_ = (hidden_dim, *out_dim[:2]) self.decoder = ConvBlock( @@ -549,8 +548,7 @@ def __init__(self, latent_dim: int, num_layers: int = 2, hidden_dim: int = 32, - num_classes: int = 0 - ) -> None: # num_classes is redundant (just pass latent_dim = latent_dim + nb_classes in init_vae_nets) + ) -> None: """ Initializes network parameters """ @@ -562,7 +560,7 @@ def __init__(self, c = out_dim[-1] if len(out_dim) > 2 else 1 decoder = [] for i in range(num_layers): - hidden_dim_ = latent_dim + num_classes if i == 0 else hidden_dim + hidden_dim_ = latent_dim if i == 0 else hidden_dim decoder.extend([nn.Linear(hidden_dim_, hidden_dim), nn.Tanh()]) self.decoder = nn.Sequential(*decoder) self.out = nn.Linear(hidden_dim, np.product(out_dim)) @@ -605,7 +603,6 @@ def __init__(self, num_layers: int, hidden_dim: int, skip: bool = False, - num_classes: int = 0 # num_classes is redundant (just pass latent_dim = latent_dim + nb_classes in init_vae_nets) ) -> None: """ Initializes network parameters @@ -619,7 +616,7 @@ def __init__(self, self.reshape_ = (out_dim[0], out_dim[1], c) self.skip = skip self.coord_latent = coord_latent( - latent_dim+num_classes, hidden_dim, not skip) + latent_dim, hidden_dim, not skip) fc_decoder = [] for i in range(num_layers): fc_decoder.extend([nn.Linear(hidden_dim, hidden_dim), nn.Tanh()]) @@ -749,16 +746,17 @@ def init_VAE_nets(in_dim: Tuple[int], discrete_dim_ = 0 if discrete_dim: discrete_dim_ = sum(discrete_dim) + nb_classes_ = nb_classes if discrete_dim_ == 0 else 0 if not coord: dnet = convDecoderNet if conv_d else fcDecoderNet decoder_net = dnet( - in_dim, latent_dim+discrete_dim_, numlayers_d, numhidden_d, - nb_classes) + in_dim, latent_dim+discrete_dim_+nb_classes_, + numlayers_d, numhidden_d) else: decoder_net = rDecoderNet( - in_dim, latent_dim+discrete_dim_, numlayers_d, numhidden_d, - skip, nb_classes) + in_dim, latent_dim+discrete_dim_+nb_classes_, + numlayers_d, numhidden_d, skip) if not discrete_dim: enet = convEncoderNet if conv_e else fcEncoderNet encoder_net = enet( diff --git a/atomai/trainers/etrainer.py b/atomai/trainers/etrainer.py index fd6ba5f8..6b65b520 100644 --- a/atomai/trainers/etrainer.py +++ b/atomai/trainers/etrainer.py @@ -287,14 +287,15 @@ def preprocess_train_data(self, tor = lambda x: torch.from_numpy(x) return tor(X), tor(y), tor(X_), tor(y_) - def save_ensemble_metadict(self) -> None: + def save_ensemble_metadict(self, filename: str = None) -> None: """ Saves meta dictionary with ensemble weights and key information about model's structure (needed to load it back) to disk """ + fname = self.filename if filename is None else filename ensemble_metadict = dc(self.meta_state_dict) ensemble_metadict["weights"] = self.ensemble_state_dict - torch.save(ensemble_metadict, self.filename + "_ensemble_metadict.tar") + torch.save(ensemble_metadict, fname + "_ensemble_metadict.tar") class EnsembleTrainer(BaseEnsembleTrainer):