Merge pull request #49 from valentingol/conditional
🆙 Update to 2.1.0
valentingol authored Sep 27, 2022
2 parents 4344d18 + 76b8d51 commit eca2512
Showing 19 changed files with 375 additions and 154 deletions.
3 changes: 3 additions & 0 deletions .coveragerc
@@ -0,0 +1,3 @@
[report]
omit=*/__init__.py
    tests/*
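The new .coveragerc excludes package __init__.py files and the test suite itself from coverage reports. As a rough illustration only, the same omit patterns can be passed to the coverage.py API directly (a sketch assuming coverage.py is installed; the project itself relies on the config file):

import coverage

cov = coverage.Coverage()
cov.start()

def sample() -> int:  # stands in for project code exercised by the tests
    return sum(range(5))

sample()
cov.stop()
cov.report(omit=['*/__init__.py', 'tests/*'])  # same patterns as .coveragerc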
41 changes: 27 additions & 14 deletions apps/eval.py
@@ -2,6 +2,7 @@

import os
import os.path as osp
from typing import Optional

import numpy as np
import torch
@@ -75,42 +76,54 @@ def test(config: ConfigType) -> None:
colored_pixel_maps = colorize_pixel_map(pixel_maps)
images, attn_list = generator.generate(z_input, pixel_maps,
with_attn=True)
_, _, proba_map = generator.proba_map(z_input, pixel_maps[0])
else:
colored_pixel_maps = None
proba_map = None
with torch.no_grad():
images, attn_list = generator.generate(z_input, with_attn=True)

# Save sample images in a grid
# Save and show sample images in a grid
img_out_dir = osp.join(config.output_dir, config.run_name, 'samples')
img_out_path = osp.join(img_out_dir, f'test_samples_step_{step}.png')
img_grid = to_img_grid(images)
pil_images = Image.fromarray(img_grid)
os.makedirs(img_out_dir, exist_ok=True)
if colored_pixel_maps:
colored_pixel_maps.show(title=f'Test Samples (run {config.run_name},'
f' step {step})')
cond_save_path = img_out_path.replace('samples', 'cond_pixels')
os.makedirs(osp.dirname(cond_save_path), exist_ok=True)
colored_pixel_maps.save(cond_save_path)
pil_images.show(title=f'Test Samples (run {config.run_name}, step {step})')
pil_images.save(img_out_path)

if config.save_attn:
# Save attention
save_and_show(img_grid, img_out_path)

# Save and show other images (if any)
cond_save_path = img_out_path.replace('samples', 'cond_pixels')
save_and_show(colored_pixel_maps, cond_save_path)
proba_save_path = img_out_path.replace('samples', 'proba_map')
save_and_show(proba_map, proba_save_path)

# Save attention (if save_attn is True)
if config.save_attn and attn_list != []:
attn_out_path = osp.join(config.output_dir, config.run_name,
'attention', 'test_gen_attn_step')
os.makedirs(attn_out_path, exist_ok=True)
attn_list = [attn.detach().cpu().numpy() for attn in attn_list]
for i, attn in enumerate(attn_list):
np.save(osp.join(attn_out_path, f'attn_{i}_step_{step}.npy'), attn)

# Compute reference indicators if not already saved
compute_save_indicators(data_loader, config)
# Compute and print metrics
metrics = evaluate(gen=generator, config=config, training=False, step=step,
save_json=False, save_csv=True)
print("Metrics w.r.t training set:")
print_metrics(metrics)


def save_and_show(image: Optional[np.ndarray], path: str) -> None:
    """Save and show image using PIL."""
    if image is None:
        return
    image_pil = Image.fromarray(image)
    image_pil.show()
    dir_path, _ = osp.split(path)
    os.makedirs(dir_path, exist_ok=True)
    image_pil.save(path)


if __name__ == '__main__':
global_config = GlobalConfig.build_from_argv(
fallback='configs/exp/base.yaml')
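The refactored test() above routes every optional output (sample grid, conditioning pixels, probability map) through the new save_and_show helper, which is a no-op when the image is None. A minimal usage sketch (the array and paths below are made up for illustration; save_and_show is the helper defined in apps/eval.py above):

import numpy as np

rng = np.random.default_rng(0)
fake_grid = (rng.random((64, 64, 3)) * 255).astype(np.uint8)

save_and_show(fake_grid, 'outputs/demo/samples/test_samples_step_0.png')  # saved and displayed
save_and_show(None, 'outputs/demo/proba_map/unused.png')  # skipped: image is None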
2 changes: 1 addition & 1 deletion configs/default/training.yaml
@@ -7,7 +7,7 @@ interrupt_threshold: -1.0
save_boxes: True # save the metric boxes
adv_loss: wgan-gp # could be either 'wgan-gp' or 'hinge'
mixed_precision: False # if True, use float16 instead of float32 for training
cond_penalty: 0.25 # only used for conditional model, weight for conditional loss
cond_penalty: 1000 # only used for conditional model, weight for conditional loss

g_ema_decay: 1.0 # decay of generator's exponential moving average (1.0 = no decay)
d_ema_decay: 1.0
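cond_penalty is documented as the weight of the conditional loss term, and this release raises it from 0.25 to 1000. A standalone illustration of how such a weight rescales that term relative to the adversarial loss (the variable names and values below are assumptions; the actual loss code is not part of this diff):

import torch

adv_loss = torch.tensor(0.5)    # adversarial term (made-up value)
cond_loss = torch.tensor(1e-3)  # conditional term, typically small (made-up value)
cond_penalty = 1000.0           # value set in configs/default/training.yaml
g_loss = adv_loss + cond_penalty * cond_loss
print(g_loss)  # tensor(1.5000)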
2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@
# Installation
config = {
'name': 'sagan-facies-modeling',
'version': '2.0.2',
'version': '2.1.0',
'description': 'Facies modeling with SAGAN.',
'author': 'Valentin Goldite',
'author_email': 'valentin.goldite@gmail.com',
18 changes: 17 additions & 1 deletion tests/utils/conftest.py
@@ -8,7 +8,7 @@
from pytest_check import check_func
from torch.utils.data import DataLoader

from utils.configs import GlobalConfig
from utils.configs import ConfigType, GlobalConfig
from utils.data.data_loader import DistributedDataLoader


@@ -64,6 +64,22 @@ def __len__(self) -> int:
)


class AttnMock(torch.nn.Module):
    """Mock class for SelfAttention."""

    def __init__(self, in_dim: int, attention_config: ConfigType):
        """Init."""
        # pylint: disable=unused-argument
        super().__init__()
        self.n_heads = attention_config.n_heads

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass."""
        attn_dim = x.shape[2] * x.shape[3]
        attn = torch.zeros(x.shape[0], self.n_heads, attn_dim, attn_dim)
        return x, attn


@check_func
def check_allclose(arr1: np.ndarray, arr2: np.ndarray) -> None:
"""Check if two arrays are all close."""
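AttnMock returns its input unchanged together with an all-zero attention map, so generator tests can patch it in place of the real SelfAttention (as done in tests/utils/gan/cond_sagan/test_cond_modules.py below). A standalone usage sketch, with SimpleNamespace standing in for the attention config since only n_heads is read:

import torch
from types import SimpleNamespace

from tests.utils.conftest import AttnMock

attn = AttnMock(in_dim=64, attention_config=SimpleNamespace(n_heads=2))
x = torch.rand(3, 64, 8, 8)
out, attn_map = attn(x)
assert out.shape == (3, 64, 8, 8)        # input is passed through unchanged
assert attn_map.shape == (3, 2, 64, 64)  # (batch, n_heads, H*W, H*W)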
52 changes: 35 additions & 17 deletions tests/utils/data/test_data_loader.py
@@ -1,13 +1,15 @@
"""Tests for utils/data/data_loader.py."""

import os
from typing import Tuple

import numpy as np
import pytest
import pytest_check as check
import torch
from pytest_mock import MockerFixture

from utils.configs import Configuration, GlobalConfig
from utils.configs import GlobalConfig
from utils.data.data_loader import (DatasetCond2D, DatasetUncond2D,
DistributedDataLoader)

@@ -18,13 +20,6 @@ def dataset_path() -> str:
return 'tests/utils/data/tmp_dataset.npy'


@pytest.fixture
def data_config() -> Configuration:
"""Return data sub-config object for testing."""
return GlobalConfig.build_from_argv(
fallback='configs/unittest/data32.yaml').data


def create_test_dataset(dataset_path: str) -> None:
"""Create test dataset (if not exist)."""
if not os.path.exists(dataset_path):
@@ -33,12 +28,31 @@ def create_test_dataset(dataset_path: str) -> None:
np.save(dataset_path, dataset)


def mock_process(mocker: MockerFixture) -> None:
    """Mock functions from utils.data.process."""
    data_one_hot = np.random.rand(7, 5, 20, 5).astype(np.float32)
    data_resize = np.random.rand(10, 20, 5).astype(np.float32)
    data_crop = np.random.rand(10, 10, 5).astype(np.float32)
    pixels = np.random.rand(10, 10, 5).astype(np.float32)
    # normalize output vectors
    data_resize /= np.sum(data_resize, axis=-1, keepdims=True)
    data_crop /= np.sum(data_crop, axis=-1, keepdims=True)

    mocker.patch('utils.data.process.to_one_hot_np', return_value=data_one_hot)
    mocker.patch('utils.data.process.resize_np', return_value=data_resize)
    mocker.patch('utils.data.process.random_crop_np', return_value=data_crop)
    mocker.patch('utils.data.process.sample_pixels_2d_np', return_value=pixels)


def test_dataset_uncond_2d(dataset_path: str,
data_config: Configuration) -> None:
configs: Tuple[GlobalConfig, GlobalConfig],
mocker: MockerFixture) -> None:
"""Test DatasetUncond2D."""
mock_process(mocker)
config, _ = configs
create_test_dataset(dataset_path)
dataset = DatasetUncond2D(dataset_path=dataset_path, data_size=10,
data_config=data_config,
data_config=config.data,
augmentation_fn=lambda x: 2 * x)
check.equal(len(dataset), 7)
sample = dataset[0]
@@ -47,7 +61,6 @@ def test_dataset_uncond_2d(dataset_path: str,
data = sample[0]
check.is_instance(data, torch.Tensor)
check.equal(data.size(), (5, 10, 10))
check.greater_equal(torch.min(data), 0)
# Sum of values should be 2 with lambda x: 2 * x augmentation
print(torch.sum(data, dim=1))
check.is_true(torch.allclose(torch.sum(data, dim=0),
@@ -56,11 +69,14 @@ def test_dataset_uncond_2d(dataset_path: str,


def test_dataset_cond_2d(dataset_path: str,
data_config: Configuration) -> None:
configs: Tuple[GlobalConfig, GlobalConfig],
mocker: MockerFixture) -> None:
"""Test DatasetCond2D."""
mock_process(mocker)
config, _ = configs
create_test_dataset(dataset_path)
dataset = DatasetCond2D(dataset_path=dataset_path, data_size=10,
data_config=data_config,
data_config=config.data,
augmentation_fn=lambda x: 2 * x)
check.equal(len(dataset), 7)
sample = dataset[0]
@@ -71,21 +87,23 @@ def test_dataset_cond_2d(dataset_path: str,
check.is_instance(pixel_maps, torch.Tensor)
check.equal(data.size(), (5, 10, 10))
check.equal(data.size(), (5, 10, 10))
check.greater_equal(torch.min(data), 0)
# Sum of values should be 2 with lambda x: 2 * x augmentation
check.is_true(torch.allclose(torch.sum(data, dim=0),
torch.tensor(2, dtype=torch.float32)))
os.remove(dataset_path)


def test_distributed_dataloader(dataset_path: str,
data_config: Configuration) -> None:
configs: Tuple[GlobalConfig, GlobalConfig],
mocker: MockerFixture) -> None:
"""Test DistributedDataLoader."""
mock_process(mocker)
config, _ = configs
create_test_dataset(dataset_path)
# Case training = True
dataloader = DistributedDataLoader(
dataset_path=dataset_path, data_size=10, training=True,
data_config=data_config, dataset_class=DatasetUncond2D,
data_config=config.data, dataset_class=DatasetUncond2D,
augmentation_fn=lambda x: 2 * x
).loader()

@@ -100,7 +118,7 @@ def test_distributed_dataloader(dataset_path: str,
# Case training = False
dataloader = DistributedDataLoader(
dataset_path=dataset_path, data_size=10, training=False,
data_config=data_config, dataset_class=DatasetUncond2D,
data_config=config.data, dataset_class=DatasetUncond2D,
augmentation_fn=lambda x: 2 * x
).loader()

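The channel-sum assertions in these tests follow directly from mock_process: every mocked crop is normalized so that the class vector of each pixel sums to 1, and the lambda x: 2 * x augmentation doubles it. A standalone sketch of that invariant:

import numpy as np

rng = np.random.default_rng(0)
data_crop = rng.random((10, 10, 5)).astype(np.float32)
data_crop /= np.sum(data_crop, axis=-1, keepdims=True)  # per-pixel class sums are 1
augmented = 2 * data_crop                               # the test augmentation
assert np.allclose(augmented.sum(axis=-1), 2.0)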
16 changes: 13 additions & 3 deletions tests/utils/data/test_process.py
@@ -6,9 +6,9 @@
from skimage.measure import label

from tests.utils.conftest import check_allclose
from utils.data.process import (color_data_np, random_crop_np, resize_np,
sample_pixels_2d_np, to_img_grid,
to_one_hot_np)
from utils.data.process import (color_data_np, continuous_color_data_np,
random_crop_np, resize_np, sample_pixels_2d_np,
to_img_grid, to_one_hot_np)


@pytest.fixture
@@ -83,6 +83,16 @@ def test_color_data_np(data_int: np.ndarray) -> None:
check.less_equal(np.max(color_data), 255)


def test_continuous_color_data_np(data_one_hot: np.ndarray) -> None:
    """Test continuous_color_data_np."""
    cont_data_one_hot = np.where(data_one_hot == 1, 0.8, 0.2)
    color_data = continuous_color_data_np(cont_data_one_hot)
    check.equal(color_data.shape, (4, 5, 3))
    check.equal(color_data.dtype, np.uint8)
    check.greater_equal(np.min(color_data), 0)
    check.less_equal(np.max(color_data), 255)


def test_to_img_grid() -> None:
"""Test to_img_grid."""
batched_images = np.random.rand(65, 8, 8, 3)
54 changes: 49 additions & 5 deletions tests/utils/gan/cond_sagan/test_cond_modules.py
@@ -1,19 +1,32 @@
"""Tests for utils/gan/cond_sagan/module.py."""

from typing import Tuple

import numpy as np
import pytest
import pytest_check as check
import torch
from pytest_mock import MockerFixture

from tests.utils.conftest import AttnMock
from utils.configs import GlobalConfig
from utils.gan.cond_sagan.modules import CondSAGenerator


def test_sa_generator(configs: Tuple[GlobalConfig, GlobalConfig]) -> None:
"""Test SAGenerator."""
@pytest.fixture
def gen(configs: Tuple[GlobalConfig, GlobalConfig],
        mocker: MockerFixture) -> CondSAGenerator:
    """Return generator for tests."""
    mocker.patch('utils.gan.initialization.init_weights')
    mocker.patch('utils.gan.attention.SelfAttention', AttnMock)
    mocker.patch('utils.gan.spectral.SpectralNorm', side_effect=lambda x: x)
    config32, _ = configs
    return CondSAGenerator(n_classes=4, model_config=config32.model)


def test_sa_generator_fwd(gen: CondSAGenerator) -> None:
"""Test CondSAGenerator.forward."""
pixel_maps = torch.randint(0, 2, size=(5, 4, 32, 32), dtype=torch.float32)
gen = CondSAGenerator(n_classes=4, model_config=config32.model)
z = torch.rand(size=(5, 128), dtype=torch.float32)
data, att_list = gen(z, pixel_maps, with_attn=True)
check.equal(data.shape, (5, 4, 32, 32))
@@ -24,10 +37,41 @@ def test_sa_generator(configs: Tuple[GlobalConfig, GlobalConfig]) -> None:
data = gen(z, pixel_maps, with_attn=False)
check.is_instance(data, torch.Tensor)

# Test generate method

def test_sa_generator_generate(gen: CondSAGenerator,
                               mocker: MockerFixture) -> None:
    """Test SAGenerator.generate."""
    mocker.patch('utils.data.process.color_data_np',
                 return_value=np.random.randint(0, 256, (5, 32, 32, 3),
                                                dtype=np.uint8))
    pixel_maps = torch.randint(0, 2, size=(5, 4, 32, 32), dtype=torch.float32)
    z = torch.rand(size=(5, 128), dtype=torch.float32)
    images, attn_list = gen.generate(z, pixel_maps, with_attn=True)
    check.is_instance(images, np.ndarray)
    check.equal(images.shape, (5, 32, 32, 3))
    check.equal(len(attn_list), 3)
    images, attn_list = gen.generate(z, pixel_maps, with_attn=False)
    check.equal(images.shape, (5, 32, 32, 3))
    check.is_instance(images, np.ndarray)
    check.equal(attn_list, [])


def test_sa_generator_proba_map(gen: CondSAGenerator,
                                mocker: MockerFixture) -> None:
    """Test SAGenerator.proba_map."""
    mocker.patch('utils.data.process.continuous_color_data_np',
                 return_value=np.random.randint(0, 256, (32, 32, 3),
                                                dtype=np.uint8))
    pixel_map = torch.randint(0, 2, size=(4, 32, 32), dtype=torch.float32)
    z = torch.rand(size=(5, 128), dtype=torch.float32)
    # Case batch_size = None
    proba_mean, proba_std, _ = gen.proba_map(z, pixel_map, batch_size=None)
    check.is_instance(proba_mean, np.ndarray)
    check.is_true(proba_mean.shape == (32, 32, 4))
    check.greater_equal(proba_mean.min(), 0.0)
    check.is_true(np.allclose(np.sum(proba_mean, axis=-1), 1.0))
    check.is_instance(proba_std, np.ndarray)
    check.is_true(proba_std.shape == (32, 32, 4))
    check.greater_equal(proba_std.min(), 0.0)
    # Case batch_size = 1
    proba_mean, _, _ = gen.proba_map(z, pixel_map, batch_size=1)
    check.is_true(proba_mean.shape == (32, 32, 4))
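The proba_map assertions encode a simple property: if each generated sample yields per-pixel class probabilities that sum to 1, their mean over the batch also sums to 1 at every pixel, and the per-class standard deviation is non-negative. A standalone sketch of that reasoning (shapes follow the test; the real proba_map lives in utils/gan/cond_sagan/modules.py):

import numpy as np

rng = np.random.default_rng(0)
probas = rng.random((5, 32, 32, 4))
probas /= probas.sum(axis=-1, keepdims=True)  # per-sample class probabilities
proba_mean = probas.mean(axis=0)
proba_std = probas.std(axis=0)
assert np.allclose(proba_mean.sum(axis=-1), 1.0)
assert proba_mean.shape == proba_std.shape == (32, 32, 4)
assert proba_std.min() >= 0.0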
27 changes: 27 additions & 0 deletions tests/utils/gan/test_initialization.py
@@ -0,0 +1,27 @@
"""Tests for utils.gan.initialization.py."""

import pytest_check as check
import torch
from torch import nn

from utils.gan.initialization import init_weights


def test_init_weights() -> None:
    """Test init_weights."""
    module = nn.Sequential(nn.Conv2d(3, 3, 3), nn.Conv2d(3, 3, 3))
    weights_1 = module[0].weight.clone()
    init_weights(module, 'default')
    weights_2 = module[0].weight.clone()
    check.is_true(torch.allclose(weights_1, weights_2))
    init_weights(module, 'orthogonal')
    weights_3 = module[0].weight.clone()
    check.is_false(torch.allclose(weights_2, weights_3))
    init_weights(module, 'glorot')
    weights_4 = module[0].weight.clone()
    check.is_false(torch.allclose(weights_3, weights_4))
    init_weights(module, 'normal')
    weights_5 = module[0].weight.clone()
    check.is_false(torch.allclose(weights_4, weights_5))
    with check.raises(ValueError):
        init_weights(module, 'UNKNOWN')
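For context, a hypothetical init_weights consistent with this test: 'default' leaves weights untouched, the other modes re-initialize Conv/Linear layers, and an unknown mode raises ValueError. This is a guess at the behaviour only; the real implementation in utils/gan/initialization.py is not shown in this diff and may differ.

from torch import nn

def init_weights_sketch(module: nn.Module, mode: str) -> None:
    """Hypothetical re-implementation of the behaviour exercised above."""
    if mode == 'default':
        return  # keep PyTorch's default initialization
    if mode not in ('orthogonal', 'glorot', 'normal'):
        raise ValueError(f'Unknown init mode: {mode}')
    for sub in module.modules():
        if isinstance(sub, (nn.Conv2d, nn.Linear)):
            if mode == 'orthogonal':
                nn.init.orthogonal_(sub.weight)
            elif mode == 'glorot':
                nn.init.xavier_uniform_(sub.weight)
            else:  # 'normal'
                nn.init.normal_(sub.weight, 0.0, 0.02)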