Merge pull request #12 from valentingol/dev
🆙 Upgrade to 0.2.2
valentingol authored Jul 29, 2022
2 parents bb0b9f9 + c19694f commit 14fd93a
Showing 15 changed files with 210 additions and 75 deletions.
8 changes: 5 additions & 3 deletions apps/test.py
@@ -10,10 +10,14 @@
from utils.configs import ConfigType, GlobalConfig
from utils.data.process import to_img_grid
from utils.sagan.modules import SAGenerator
from utils.train.random_utils import set_global_seed


def test(config: ConfigType) -> None:
"""Test the generator."""
# For reproducibility
set_global_seed(seed=config.seed)

architecture = config.model.architecture
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = config.test_batch_size
@@ -29,9 +33,7 @@ def test(config: ConfigType) -> None:

if architecture == 'sagan':
generator = SAGenerator(n_classes=n_classes,
data_size=config.model.data_size,
z_dim=config.model.z_dim,
conv_dim=config.model.g_conv_dim).to(device)
model_config=config.model).to(device)

z_input = torch.randn(batch_size, config.model.z_dim, device=device)
generator.load_state_dict(torch.load(model_path))
8 changes: 7 additions & 1 deletion apps/train.py
@@ -9,6 +9,7 @@
from utils.configs import ConfigType, GlobalConfig
from utils.data.data_loader import DataLoader2DFacies
from utils.sagan.trainer import TrainerSAGAN
from utils.train.random_utils import set_global_seed


def train(config: ConfigType) -> None:
@@ -19,8 +20,12 @@ def train(config: ConfigType) -> None:
batch_size = config.training.batch_size
architecture = config.model.architecture

# For fast training
# Improve reproducibility
set_global_seed(seed=config.seed)

# For faster training (but reduces reproducibility!)
cudnn.benchmark = True

# Data loader
data_loader = DataLoader2DFacies(dataset_path=config.dataset_path,
data_size=config.model.data_size,
@@ -84,5 +89,6 @@ def main() -> None:
if __name__ == '__main__':
global_config = GlobalConfig.build_from_argv(
fallback='configs/exp/base.yaml')

global_config.save(osp.join(global_config.config_save_path, 'config'))
main()
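The updated comment above notes that `cudnn.benchmark = True` speeds up training at the cost of reproducibility. For reference, a fully deterministic PyTorch setup would flip those cuDNN flags instead; this is only a sketch of the trade-off, not code from this commit:

```python
import torch
import torch.backends.cudnn as cudnn

# Sketch of the opposite trade-off to apps/train.py: fully deterministic,
# but slower, convolution behaviour.
cudnn.benchmark = False      # do not auto-tune convolution algorithms
cudnn.deterministic = True   # force deterministic cuDNN kernels
torch.use_deterministic_algorithms(True, warn_only=True)  # warn_only needs PyTorch >= 1.11
```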
1 change: 1 addition & 0 deletions configs/default/main.yaml
@@ -2,6 +2,7 @@
run_name: null # (should be overwritten by the experiment config)
config_save_path: null # (should be overwritten by the experiment config)
dataset_path: null # (should be overwritten by the experiment config)
seed: 0
use_wandb: False
recover_model_step: 0 # the step to recover the model, 0 to not recover
save_attn: False
1 change: 1 addition & 0 deletions configs/default/model.yaml
@@ -1,5 +1,6 @@
--- !model
architecture: sagan
init_method: 'default' # one of 'default', 'orthogonal', 'glorot' or 'normal'
data_size: 64 # could be either 32 or 64
full_values: True
z_dim: 128
1 change: 1 addition & 0 deletions configs/default/training.yaml
@@ -15,6 +15,7 @@ weight_decay: 0.0
# Step sizes
total_step: 100000
total_time: -1 # in sec, < 0 means no time limit
ema_start_step: 0 # only used if g_ema_decay < 1.0
log_step: 10
sample_step: 400 # save generated images every n steps
model_save_step: 1200 # save generator and discriminator every n steps
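The new `ema_start_step` key is documented as taking effect only when `g_ema_decay < 1.0`. The trainer itself is not among the hunks shown here, so the following is only a sketch, under that assumption, of how an EMA copy of the generator could be gated by these two settings (the `ema_generator` and `step` names are illustrative, not taken from the repository):

```python
import torch


def update_generator_ema(generator: torch.nn.Module,
                         ema_generator: torch.nn.Module,
                         step: int, ema_start_step: int,
                         g_ema_decay: float) -> None:
    """Sketch: exponential moving average of the generator weights."""
    # Assumed behaviour: no EMA when g_ema_decay >= 1.0, and no update
    # before ema_start_step.
    if g_ema_decay >= 1.0 or step < ema_start_step:
        return
    with torch.no_grad():
        for p_ema, param in zip(ema_generator.parameters(),
                                generator.parameters()):
            p_ema.mul_(g_ema_decay).add_(param, alpha=1.0 - g_ema_decay)
```

A typical usage would create `ema_generator` as a deep copy of the generator and call this once per training step.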
5 changes: 5 additions & 0 deletions configs/unittest/data32.yaml
@@ -5,9 +5,14 @@ dataset_path: null # use custom simple dataset instead
save_attn: True

model.data_size: 32
model.d_conv_dim: 64
model.g_conv_dim: 64
model.full_values: True
model.init_method: orthogonal

training.adv_loss: hinge
training.g_ema_decay: 0.999
training.ema_start_step: 0

training.batch_size: 2
training.log_step: 1
3 changes: 3 additions & 0 deletions configs/unittest/data64.yaml
@@ -5,7 +5,10 @@ dataset_path: null # use custom simple dataset instead
save_attn: True

model.data_size: 64
model.d_conv_dim: 64
model.g_conv_dim: 64
model.full_values: False
model.init_method: normal

training.adv_loss: wgan-gp

2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@
# Installation
config = {
'name': 'sagan-facies-modeling',
'version': '0.2.1',
'version': '0.2.2',
'description': 'Facies modeling with SAGAN.',
'author': 'Valentin Goldite',
'author_email': 'valentin.goldite@gmail.com',
42 changes: 30 additions & 12 deletions tests/utils/sagan/test_modules.py
@@ -1,10 +1,24 @@
"""Tests for sagan/modules.py."""

from typing import Tuple

import pytest
import torch

from utils.configs import GlobalConfig
from utils.sagan.modules import SADiscriminator, SAGenerator, SelfAttention


@pytest.fixture
def configs() -> Tuple[GlobalConfig, GlobalConfig]:
"""Return configs with data size 32 and 64"""
config32 = GlobalConfig().build_from_argv(
fallback='configs/unittest/data32.yaml')
config64 = GlobalConfig().build_from_argv(
fallback='configs/unittest/data64.yaml')
return config32, config64


def test_self_attention() -> None:
"""Test SelfAttention."""
self_att = SelfAttention(in_dim=200)
@@ -17,21 +31,21 @@ def test_self_attention() -> None:
torch.tensor(1, dtype=torch.float32))


def test_sa_discriminator() -> None:
def test_sa_discriminator(configs: Tuple[GlobalConfig, GlobalConfig]) -> None:
"""Test SADiscriminator."""
# data_size = 32, batch_size = 5, full_values = False
disc = SADiscriminator(n_classes=4, data_size=32, conv_dim=64,
full_values=False)
config32, config64 = configs

# config 32
disc = SADiscriminator(n_classes=4, model_config=config32.model)
x = torch.rand(size=(5, 4, 32, 32), dtype=torch.float32)
x = x / torch.sum(x, dim=1, keepdim=True) # normalize
preds, att_list = disc(x)
assert len(att_list) == 1
assert att_list[0].shape == (5, 16, 16)
assert preds.shape == (5,)

# data_size = 64, batch_size = 1, full_values = True
disc = SADiscriminator(n_classes=4, data_size=64, conv_dim=64,
full_values=True)
# config 64
disc = SADiscriminator(n_classes=4, model_config=config64.model)
x = torch.rand(size=(1, 4, 64, 64), dtype=torch.float32)
x = x / torch.sum(x, dim=1, keepdim=True) # normalize
preds, att_list = disc(x)
@@ -41,23 +55,27 @@ def test_sa_discriminator() -> None:
assert att_list[1].shape == (1, 16, 16)


def test_sa_generator() -> None:
def test_sa_generator(configs: Tuple[GlobalConfig, GlobalConfig]) -> None:
"""Test SAGenerator."""
# data_size = 32, batch_size = 5
gen = SAGenerator(n_classes=4, data_size=32, z_dim=128, conv_dim=64)
config32, config64 = configs

# config 32
gen = SAGenerator(n_classes=4, model_config=config32.model)
z = torch.rand(size=(5, 128), dtype=torch.float32)
data, att_list = gen(z)
assert data.shape == (5, 4, 32, 32)
assert len(att_list) == 1
assert att_list[0].shape == (5, 256, 256)
# data_size = 64, batch_size = 1
gen = SAGenerator(n_classes=4, data_size=64, z_dim=128, conv_dim=64)

# config 64
gen = SAGenerator(n_classes=4, model_config=config64.model)
z = torch.rand(size=(1, 128), dtype=torch.float32)
data, att_list = gen(z)
assert data.shape == (1, 4, 64, 64)
assert len(att_list) == 2
assert att_list[0].shape == (1, 256, 256)
assert att_list[1].shape == (1, 1024, 1024)

# Test generate method
images, _ = gen.generate(z, with_attn=True)
assert images.shape == (1, 64, 64, 3)
17 changes: 17 additions & 0 deletions tests/utils/train/test_random_utils.py
@@ -0,0 +1,17 @@
"""Test /utils/train/random_utils.py."""

import random

import numpy as np
import torch

from utils.train.random_utils import set_global_seed


def test_set_global_seed() -> None:
"""Test set_global_seed."""
set_global_seed(0)
# Test the expected values for this particular seed.
assert torch.randint(1000, size=(1,)) == 44
assert np.random.randint(0, 1000) == 684
assert random.randint(0, 1000) == 864
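`utils/train/random_utils.py` itself is among the files not expanded in this view. A minimal sketch consistent with the test above, assuming it seeds Python's `random`, NumPy, and PyTorch, would be:

```python
import random

import numpy as np
import torch


def set_global_seed(seed: int = 0) -> None:
    """Sketch (assumed implementation): seed every RNG used in the project."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op when CUDA is unavailable
```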
15 changes: 15 additions & 0 deletions tests/utils/train/test_time_utils.py
@@ -0,0 +1,15 @@
"""Test /utils/train/time_utils.py."""

import time

from utils.train.time_utils import get_delta_eta


def test_get_delta_eta() -> None:
"""Test get_delta_eta."""
params = {'start_step': 1000, 'step': 2000, 'total_step': 3000}
start_time = time.time() - 90 # simulate 1 minute 30 seconds elapsed
delta_str, eta_str = get_delta_eta(start_time=start_time, **params)
# NOTE: allow a margin of error to account for the call time
assert delta_str in [f'00h01m{i}s' for i in range(25, 35)]
assert eta_str in [f'00h01m{i}s' for i in range(25, 35)]
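Likewise, `utils/train/time_utils.py` is not shown in this view. A sketch of `get_delta_eta` that matches the assertions above (elapsed time plus a linear extrapolation over the remaining steps, formatted as `00h01m30s`) could look like this; it is an assumption, not the committed code:

```python
import time
from typing import Tuple


def _format_hms(seconds: float) -> str:
    """Format a duration as 'HHhMMmSSs'."""
    total = int(seconds)
    hours, rem = divmod(total, 3600)
    minutes, secs = divmod(rem, 60)
    return f'{hours:02d}h{minutes:02d}m{secs:02d}s'


def get_delta_eta(start_time: float, start_step: int, step: int,
                  total_step: int) -> Tuple[str, str]:
    """Sketch: elapsed time and ETA assuming a constant time per step."""
    delta = time.time() - start_time
    time_per_step = delta / max(step - start_step, 1)
    eta = time_per_step * (total_step - step)
    return _format_hms(delta), _format_hms(eta)
```

With the test parameters (90 s elapsed over 1000 steps since `start_step`, 1000 steps remaining), both strings come out around '00h01m30s', inside the asserted range.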
81 changes: 61 additions & 20 deletions utils/sagan/modules.py
@@ -8,6 +8,7 @@
import torch
from torch import nn

from utils.configs import ConfigType
from utils.data.process import color_data_np
from utils.sagan.spectral import SpectralNorm

@@ -23,12 +24,13 @@ class SelfAttention(nn.Module):
Input feature map dimension (channels).
att_dim : int, optional
Attention map dimension for each query and key
(and value if full_values is False). By default, in_dim // 8.
(and value if full_values is False).
By default, in_dim // 8.
full_values : bool, optional
Whether to have value dimension equal to full dimension (in_dim)
or reduced to in_dim // 2. In the latter case, the output of the
attention is projected to full dimension by an additional
1*1 convolution. By default, True.
Whether to have value dimension equal to full dimension
(in_dim) or reduced to in_dim // 2. In the latter case,
the output of the attention is projected to full dimension
by an additional 1*1 convolution. By default, True.
"""

def __init__(self, in_dim: int, att_dim: Optional[int] = None,
@@ -96,37 +98,56 @@ def forward(self, x: torch.Tensor) -> TensorWithAttn:
class SADiscriminator(nn.Module):
"""Self-attention discriminator."""

def __init__(self, n_classes: int, data_size: int = 64,
conv_dim: int = 64, full_values: bool = True) -> None:
def __init__(self, n_classes: int, model_config: ConfigType) -> None:
super().__init__()
self.n_classes = n_classes
self.data_size = data_size
self.data_size = model_config.data_size

self.conv1 = self._make_disc_block(n_classes, conv_dim, kernel_size=4,
stride=2, padding=1)
self.conv1 = self._make_disc_block(n_classes, model_config.d_conv_dim,
kernel_size=4, stride=2, padding=1)

current_dim = conv_dim
current_dim = model_config.d_conv_dim
self.conv2 = self._make_disc_block(current_dim, current_dim * 2,
kernel_size=4, stride=2, padding=1)

current_dim = current_dim * 2
self.conv3 = self._make_disc_block(current_dim, current_dim * 2,
kernel_size=4, stride=2, padding=1)

self.attn1 = SelfAttention(current_dim * 2, full_values=full_values)
self.attn1 = SelfAttention(current_dim * 2,
full_values=model_config.full_values)

if self.data_size == 64:
current_dim = current_dim * 2
self.conv4 = self._make_disc_block(current_dim, current_dim * 2,
kernel_size=4, stride=2,
padding=1)
self.attn2 = SelfAttention(current_dim * 2,
full_values=full_values)
full_values=model_config.full_values)

current_dim = current_dim * 2
self.conv_last = nn.Sequential(
nn.Conv2d(current_dim, 1, kernel_size=4),)

self.init_weights(model_config.init_method)

def init_weights(self, init_method: str) -> None:
"""Initialize weights."""
if init_method == 'default':
return
for _, param in self.named_parameters():
if param.ndim == 4:
if init_method == 'orthogonal':
nn.init.orthogonal_(param)
elif init_method == 'glorot':
nn.init.xavier_uniform_(param)
elif init_method == 'normal':
nn.init.normal_(param, 0, 0.02)
else:
raise ValueError(
f'Unknown init method: {init_method}. Should be one '
'of "default", "orthogonal", "glorot", "normal".')

def _make_disc_block(self, in_channels: int, out_channels: int,
kernel_size: int, stride: int,
padding: int) -> nn.Module:
@@ -179,42 +200,62 @@ def forward(self, x: torch.Tensor) -> TensorWithAttn:
class SAGenerator(nn.Module):
"""Self-attention generator."""

def __init__(self, n_classes: int, data_size: int = 64, z_dim: int = 128,
conv_dim: int = 64, full_values: bool = True) -> None:
def __init__(self, n_classes: int, model_config: ConfigType) -> None:
super().__init__()
self.n_classes = n_classes
self.data_size = data_size
self.data_size = model_config.data_size

repeat_num = int(np.log2(self.data_size)) - 3
mult = 2**repeat_num # 8 if data_size=64, 4 if data_size=32

self.conv1 = self._make_gen_block(z_dim, conv_dim * mult,
self.conv1 = self._make_gen_block(model_config.z_dim,
model_config.g_conv_dim * mult,
kernel_size=4, stride=1, padding=0)

current_dim = conv_dim * mult
current_dim = model_config.g_conv_dim * mult
self.conv2 = self._make_gen_block(current_dim, current_dim // 2,
kernel_size=4, stride=2, padding=1)

current_dim = current_dim // 2
self.conv3 = self._make_gen_block(current_dim, current_dim // 2,
kernel_size=4, stride=2, padding=1)

self.attn1 = SelfAttention(current_dim // 2, full_values=full_values)
self.attn1 = SelfAttention(current_dim // 2,
full_values=model_config.full_values)

if self.data_size == 64:
current_dim = current_dim // 2
self.conv4 = self._make_gen_block(current_dim, current_dim // 2,
kernel_size=4, stride=2,
padding=1)
self.attn2 = SelfAttention(current_dim // 2,
full_values=full_values)
full_values=model_config.full_values)

current_dim = current_dim // 2

self.conv_last = nn.Sequential(
nn.ConvTranspose2d(current_dim, n_classes, kernel_size=4, stride=2,
padding=1))

self.init_weights(model_config.init_method)

def init_weights(self, init_method: str) -> None:
"""Initialize weights."""
if init_method == 'default':
return
for _, param in self.named_parameters():
if param.ndim == 4:
if init_method == 'orthogonal':
nn.init.orthogonal_(param)
elif init_method == 'glorot':
nn.init.xavier_uniform_(param)
elif init_method == 'normal':
nn.init.normal_(param, 0, 0.02)
else:
raise ValueError(
f'Unknown init method: {init_method}. Should be one '
'of "default", "orthogonal", "glorot", "normal".')

def _make_gen_block(self, in_channels: int, out_channels: int,
kernel_size: int, stride: int,
padding: int) -> nn.Module:
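
Taken together, the modules.py changes mean both networks are now built from the model config instead of individual keyword arguments, with weight initialization controlled by the new `init_method` key. A usage sketch mirroring the updated unit tests (the config path is the one used by those tests):

```python
import torch

from utils.configs import GlobalConfig
from utils.sagan.modules import SADiscriminator, SAGenerator

# Build a config the same way the tests do (data size 64, init_method 'normal').
config = GlobalConfig().build_from_argv(
    fallback='configs/unittest/data64.yaml')

# New-style constructors: hyper-parameters are read from config.model and
# the weights are initialized according to config.model.init_method.
generator = SAGenerator(n_classes=4, model_config=config.model)
discriminator = SADiscriminator(n_classes=4, model_config=config.model)

z = torch.randn(2, config.model.z_dim)
data, gen_attn = generator(z)           # data: (2, 4, 64, 64)
preds, disc_attn = discriminator(data)  # preds: (2,)
```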