Skip to content

Commit

Permalink
Merge pull request #10 from ziatdinovmax/master
Browse files Browse the repository at this point in the history
Improvements and bug fixes
  • Loading branch information
ziatdinovmax authored Mar 10, 2021
2 parents 7fb6dcd + ea38f75 commit 143a707
Show file tree
Hide file tree
Showing 8 changed files with 156 additions and 134 deletions.
95 changes: 46 additions & 49 deletions atomai/losses_metrics/vi_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def vae_loss(recon_loss: str,
x: torch.Tensor,
x_reconstr: torch.Tensor,
*args: torch.Tensor,
**kwargs: List[float]
) -> torch.Tensor:
"""
Calculates ELBO
Expand All @@ -98,17 +99,22 @@ def vae_loss(recon_loss: str,
else:
raise ValueError(
"Pass mean and SD values of encoded distribution as args")
capacity = kwargs.get("capacity")
num_iter = kwargs.get("num_iter", 0)
likelihood = -reconstruction_loss(recon_loss, in_dim, x, x_reconstr).mean()
kl_z = kld_normal(q_param).mean()
return likelihood - kl_z
kl_div = kld_normal(q_param).mean()
if capacity is not None:
kl_div = infocapacity(kl_div, capacity, num_iter=num_iter)
return likelihood - kl_div


def rvae_loss(recon_loss: str,
in_dim: Tuple[int],
x: torch.Tensor,
x_reconstr: torch.Tensor,
*args: torch.Tensor,
**kwargs: float) -> torch.Tensor:
**kwargs: Union[List[float], float]
) -> torch.Tensor:
"""
Calculates ELBO
"""
Expand All @@ -118,13 +124,16 @@ def rvae_loss(recon_loss: str,
raise ValueError(
"Pass mean and SD values of encoded distribution as args")
phi_prior = kwargs.get("phi_prior", 0.1)
b1, b2 = kwargs.get("b1", 1), kwargs.get("b2", 1)
capacity = kwargs.get("capacity")
num_iter = kwargs.get("num_iter", 0)
phi_logsd = z_logsd[:, 0]
z_mean, z_logsd = z_mean[:, 1:], z_logsd[:, 1:]
likelihood = -reconstruction_loss(recon_loss, in_dim, x, x_reconstr).mean()
kl_rot = kld_rot(phi_prior, phi_logsd).mean()
kl_z = kld_normal([z_mean, z_logsd]).mean()
kl_div = (b1*kl_z + b2 * kl_rot)
kl_div = (kl_z + kl_rot)
if capacity is not None:
kl_div = infocapacity(kl_div, capacity, num_iter=num_iter)
return likelihood - kl_div


Expand All @@ -145,8 +154,8 @@ def joint_vae_loss(recon_loss: str,
"Pass continuous (mean, SD) and discrete (alphas) values" +
"of encoded distributions as args")

cont_capacity = kwargs.get("cont_capacity", [0.0, 5.0, 25000, 30])
disc_capacity = kwargs.get("disc_capacity", [0.0, 5.0, 25000, 30])
cont_capacity = kwargs.get("cont_capacity", [5.0, 25000, 30])
disc_capacity = kwargs.get("disc_capacity", [5.0, 25000, 30])
num_iter = kwargs.get("num_iter", 0)
disc_dims = [a.size(1) for a in alphas]

Expand All @@ -160,7 +169,7 @@ def joint_vae_loss(recon_loss: str,
kl_disc_loss = torch.sum(torch.cat(kl_disc))

# Apply information capacity terms to contninuous and discrete channels
cargs = [kl_cont_loss, kl_disc_loss, cont_capacity,
cargs = [kl_cont_loss, cont_capacity, kl_disc_loss,
disc_capacity, disc_dims, num_iter]
cont_capacity_loss, disc_capacity_loss = infocapacity(*cargs)

Expand All @@ -172,7 +181,8 @@ def joint_rvae_loss(recon_loss: str,
x: torch.Tensor,
x_reconstr: torch.Tensor,
*args: torch.Tensor,
**kwargs: float) -> torch.Tensor:
**kwargs: Union[List, float, int]
) -> torch.Tensor:
"""
Calculates joint ELBO for continuous and discrete variables
"""
Expand All @@ -184,9 +194,8 @@ def joint_rvae_loss(recon_loss: str,
"of encoded distributions as args")

phi_prior = kwargs.get("phi_prior", 0.1)
klrot_cap = kwargs.get("klrot_cap", True)
cont_capacity = kwargs.get("cont_capacity", [0.0, 5.0, 25000, 30])
disc_capacity = kwargs.get("disc_capacity", [0.0, 5.0, 25000, 30])
cont_capacity = kwargs.get("cont_capacity", [5.0, 25000, 30])
disc_capacity = kwargs.get("disc_capacity", [5.0, 25000, 30])
num_iter = kwargs.get("num_iter", 0)

# Calculate reconstruction loss term
Expand All @@ -197,58 +206,46 @@ def joint_rvae_loss(recon_loss: str,
z_mean, z_logsd = z_mean[:, 1:], z_logsd[:, 1:] # image content
kl_rot = kld_rot(phi_prior, phi_logsd).mean()
kl_z = kld_normal([z_mean, z_logsd]).mean()
if klrot_cap:
kl_cont_loss = kl_z + kl_rot
else: # no capacity limit on KL term associated with rotations
kl_cont_loss = kl_z
kl_cont_loss = kl_z + kl_rot

# Calculate KL term for discrete latent variables
disc_dims = [a.size(1) for a in alphas]
kl_disc = [kld_discrete(alpha) for alpha in alphas]
kl_disc_loss = torch.sum(torch.cat(kl_disc))

# Apply information capacity terms to contninuous and discrete channels
cargs = [kl_cont_loss, kl_disc_loss, cont_capacity,
cargs = [kl_cont_loss, cont_capacity, kl_disc_loss,
disc_capacity, disc_dims, num_iter]
cont_capacity_loss, disc_capacity_loss = infocapacity(*cargs)
if not klrot_cap:
cont_capacity_loss = cont_capacity_loss + kl_rot

return likelihood - cont_capacity_loss - disc_capacity_loss


def infocapacity(kl_cont_loss: torch.Tensor,
kl_disc_loss: torch.Tensor,
cont_capacity: List[float],
disc_capacity: List[float],
disc_dims: List[int],
num_iter: int) -> torch.Tensor:
"""
Controls information capacity of the continuous and discrete loss
(based on https://arxiv.org/pdf/1804.00104.pdf &
https://github.com/Schlumberger/joint-vae/blob/master/jointvae/training.py)
"""
# Linearly increase capacity of continuous channels
cont_min, cont_max, cont_num_iters, cont_gamma = cont_capacity
# Increase continuous capacity without exceeding cont_max
cont_cap_current = (cont_max - cont_min) * num_iter
cont_cap_current = cont_cap_current / float(cont_num_iters) + cont_min
cont_cap_current = min(cont_cap_current, cont_max)
# Calculate continuous capacity loss
cont_capacity_loss = cont_gamma*torch.abs(cont_cap_current - kl_cont_loss)

# Linearly increase capacity of discrete channels
disc_min, disc_max, disc_num_iters, disc_gamma = disc_capacity
# Increase discrete capacity without exceeding disc_max or theoretical
# maximum (i.e. sum of log of dimension of each discrete variable)
disc_cap_current = (disc_max - disc_min) * num_iter
disc_cap_current = disc_cap_current / float(disc_num_iters) + disc_min
disc_cap_current = min(disc_cap_current, disc_max)
# Require float conversion here to not end up with numpy float
kl_disc_loss: Optional[torch.Tensor] = None,
disc_capacity: Optional[List[float]] = None,
disc_dims: Optional[List[int]] = None,
num_iter: int = 0
) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
"""
Controls information capacity of KL term(s)
(see https://arxiv.org/pdf/1804.03599.pdf & https://arxiv.org/pdf/1804.00104.pdf)
"""
# Increase capacity of continuous latent channel
cont_max, cont_num_iters, cont_gamma = cont_capacity
cont_cap = cont_max * (num_iter / float(cont_num_iters))
cont_cap = min(cont_cap, cont_max)
# Calculate continuous KL term
cont_capacity_loss = cont_gamma * torch.abs(kl_cont_loss - cont_cap)
if kl_disc_loss is None:
return cont_capacity_loss
# Increase capacity of discrete latent channel
disc_max, disc_num_iters, disc_gamma = disc_capacity
disc_theory_max = sum([float(np.log(d)) for d in disc_dims])
disc_cap_current = min(disc_cap_current, disc_theory_max)
# Calculate discrete capacity loss
disc_capacity_loss = disc_gamma*torch.abs(disc_cap_current - kl_disc_loss)
disc_cap = disc_max * (num_iter / float(disc_num_iters))
disc_cap = min(disc_cap, disc_max, disc_theory_max)
# Calculate discrete KL term
disc_capacity_loss = disc_gamma * torch.abs(disc_cap - kl_disc_loss)

return cont_capacity_loss, disc_capacity_loss

37 changes: 16 additions & 21 deletions atomai/models/dgm/jrvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ class jrVAE(BaseVAE):
List specifying dimensionalities of discrete (Gumbel-Softmax)
latent variables associated with image content
nb_classes:
Number of classes for class-conditional VAE
Number of classes for class-conditional VAE.
(leave it at 0 to learn discrete latent reprenetations)
translation:
account for xy shifts of image content (Default: True)
seed:
Expand Down Expand Up @@ -87,7 +88,6 @@ def __init__(self,
self.translation = translation
self.dx_prior = None
self.phi_prior = None
self.anneal_dict = None
self.kdict_ = dc(kwargs)
self.kdict_["num_iter"] = 0

Expand Down Expand Up @@ -166,13 +166,13 @@ def fit(self,
3D or 4D stack of training images with dimensions
(n_images, height, width) for grayscale data or
or (n_images, height, width, channels) for multi-channel data
X_test:
3D or 4D stack of test images with the same dimensions
as for the X_train (Default: None)
y_train:
Vector with labels of dimension (n_images,), where n_images
is a number of training images
y_train:
X_test:
3D or 4D stack of test images with the same dimensions
as for the X_train (Default: None)
y_test:
Vector with labels of dimension (n_images,), where n_images
is a number of test images
loss:
Expand All @@ -184,28 +184,23 @@ def fit(self,
**temperature (float):
Relaxation parameter for Gumbel-Softmax distribution
**cont_capacity (list):
List containing (min_capacity, max_capacity, num_iters, gamma_z)
parameters to control the capacity of the continuous latent
channels. Default values: [0.0, 5.0, 25000, 30].
Based on https://arxiv.org/abs/1804.00104
List containing (max_capacity, num_iters, gamma) parameters
to control the capacity of the continuous latent channel.
Default values: [5.0, 25000, 30].
Based on https://arxiv.org/pdf/1804.03599.pdf & https://arxiv.org/abs/1804.00104
**disc_capacity (list):
List containing (min_capacity, max_capacity, num_iters, gamma_c)
parameters to control the capacity of the discrete latent channels.
Default values: [0.0, 5.0, 25000, 30].
Based on https://arxiv.org/abs/1804.00104
**klrot_cap (bool):
Do not control capacity of KL term associated
with rotations of coordinate grid
List containing (max_capacity, num_iters, gamma) parameters
to control the capacity of the discrete latent channel(s).
Default values: [5.0, 25000, 30].
Based on https://arxiv.org/pdf/1804.03599.pdf & https://arxiv.org/abs/1804.00104
**filename (str):
file path for saving model aftereach training cycle ("epoch")
file path for saving model after each training cycle ("epoch")
"""
self._check_inputs(X_train, y_train, X_test, y_test)
self.dx_prior = kwargs.get("translation_prior", 0.1)
self.kdict_["phi_prior"] = kwargs.get("rotation_prior", 0.1)
self.anneal_dict = kwargs.get("anneal_dict")
for k, v in kwargs.items():
if k in ["cont_capacity", "disc_capacity",
"temperature", "klrot_cap"]:
if k in ["cont_capacity", "disc_capacity", "temperature"]:
self.kdict_[k] = v
self.compile_trainer(
(X_train, y_train), (X_test, y_test), **kwargs)
Expand Down
25 changes: 13 additions & 12 deletions atomai/models/dgm/jvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class jVAE(BaseVAE):
latent variables associated with image content
nb_classes:
Number of classes for class-conditional VAE
(leave it at 0 to learn discrete latent reprenetations)
seed:
seed for torch and numpy (pseudo-)random numbers generators
**conv_encoder (bool):
Expand Down Expand Up @@ -148,27 +149,27 @@ def fit(self,
(n_images, height, width) for grayscale data or
or (n_images, height, width, channels) for multi-channel data.
For spectra, 2D stack of spectra with dimensions (length,)
X_test:
3D or 4D stack of test images or 2D stack of spectra with
the same dimensions as for the X_train (Default: None)
y_train:
Vector with labels of dimension (n_images,), where n_images
is a number of training images/spectra
y_train:
X_test:
3D or 4D stack of test images or 2D stack of spectra with
the same dimensions as for the X_train (Default: None)
y_test:
Vector with labels of dimension (n_images,), where n_images
is a number of test images/spectra
loss:
reconstruction loss function, "ce" or "mse" (Default: "mse")
**cont_capacity (list):
List containing (min_capacity, max_capacity, num_iters, gamma_z)
parameters to control the capacity of the continuous latent
channels. Default values: [0.0, 5.0, 25000, 30].
Based on https://arxiv.org/abs/1804.00104
List containing (max_capacity, num_iters, gamma) parameters
to control the capacity of the continuous latent channel.
Default values: [5.0, 25000, 30].
Based on https://arxiv.org/pdf/1804.03599.pdf & https://arxiv.org/abs/1804.00104
**disc_capacity (list):
List containing (min_capacity, max_capacity, num_iters, gamma_c)
parameters to control the capacity of the discrete latent channels.
Default values: [0.0, 5.0, 25000, 30].
Based on https://arxiv.org/abs/1804.00104
List containing (max_capacity, num_iters, gamma) parameters
to control the capacity of the discrete latent channel(s).
Default values: [5.0, 25000, 30].
Based on https://arxiv.org/pdf/1804.03599.pdf & https://arxiv.org/abs/1804.00104
**filename (str):
file path for saving model aftereach training cycle ("epoch")
"""
Expand Down
Loading

0 comments on commit 143a707

Please sign in to comment.