Commit

Merge pull request #20 from deel-ai/feat-parametrization-for-normalization

Add parametrization for normalization
franckma31 authored Oct 14, 2024
2 parents 3886890 + b8a0b3d commit 49b1a50
Showing 18 changed files with 166 additions and 243 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-lints.yml
@@ -8,7 +8,7 @@ jobs:
strategy:
max-parallel: 4
matrix:
python-version: [3.6, 3.7, 3.8]
python-version: [3.7, 3.8, 3.9]

steps:
- uses: actions/checkout@v1
2 changes: 1 addition & 1 deletion .github/workflows/python-tests.yml
@@ -8,7 +8,7 @@ jobs:
strategy:
max-parallel: 4
matrix:
python-version: [3.6, 3.7, 3.8]
python-version: [3.7, 3.8, 3.9]

steps:
- uses: actions/checkout@v1
1 change: 1 addition & 0 deletions deel/torchlip/__init__.py
@@ -55,4 +55,5 @@
"Sequential",
"SpectralConv2d",
"SpectralLinear",
"vanilla_model",
]
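With `vanilla_model` added to `__all__`, the export helper becomes importable from the package root. A minimal usage sketch (not part of the diff; layer sizes are illustrative assumptions):

from deel import torchlip

model = torchlip.Sequential(
    torchlip.SpectralLinear(20, 40),
    torchlip.SpectralLinear(40, 10),
)
# In-place conversion: SpectralLinear children are replaced by plain torch.nn.Linear
torchlip.vanilla_model(model)
print(model)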
1 change: 1 addition & 0 deletions deel/torchlip/modules/__init__.py
@@ -64,6 +64,7 @@
from .loss import NegKRLoss
from .module import LipschitzModule
from .module import Sequential
from .module import vanilla_model
from .pooling import ScaledAdaptiveAvgPool2d
from .pooling import ScaledAvgPool2d
from .pooling import ScaledL2NormPool2d
8 changes: 4 additions & 4 deletions deel/torchlip/modules/conv.py
@@ -27,7 +27,7 @@
import numpy as np
import torch
from torch.nn.common_types import _size_2_t
from torch.nn.utils import spectral_norm
from torch.nn.utils.parametrizations import spectral_norm

from ..utils import bjorck_norm
from ..utils import DEFAULT_NITER_BJORCK
@@ -111,8 +111,8 @@ def __init__(
n_power_iterations=niter_spectral,
)
bjorck_norm(self, name="weight", n_iterations=niter_bjorck)
lconv_norm(self)
self.register_forward_pre_hook(self._hook)
lconv_norm(self, name="weight")
self.apply_lipschitz_factor()

def vanilla_export(self):
layer = torch.nn.Conv2d(
@@ -172,7 +172,7 @@ def __init__(

frobenius_norm(self, name="weight", disjoint_neurons=False)
lconv_norm(self)
self.register_forward_pre_hook(self._hook)
self.apply_lipschitz_factor()

def vanilla_export(self):
layer = torch.nn.Conv2d(
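With this change the Lipschitz coefficient is applied through a weight parametrization rather than a forward pre-hook, so it can be inspected alongside the spectral, Björck and lconv parametrizations. A hedged sketch (constructor arguments such as `k_coef_lip` are assumptions based on the existing torchlip API, not confirmed by this diff):

import torch.nn.utils.parametrize as parametrize
from deel.torchlip import SpectralConv2d

# k_coef_lip != 1 so the coefficient parametrization is also registered (argument name assumed)
conv = SpectralConv2d(3, 8, kernel_size=3, k_coef_lip=2.0)
print(parametrize.is_parametrized(conv, "weight"))   # True
for p in conv.parametrizations.weight:
    # Expected registration order per the constructor above:
    # spectral norm, Bjorck, lconv, then the Lipschitz coefficient multiplication
    print(type(p).__name__)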
6 changes: 3 additions & 3 deletions deel/torchlip/modules/linear.py
@@ -25,7 +25,7 @@
# CRIAQ and ANITI - https://www.deel.ai/
# =====================================================================================
import torch
from torch.nn.utils import spectral_norm
from torch.nn.utils.parametrizations import spectral_norm

from ..utils import bjorck_norm
from ..utils import DEFAULT_NITER_BJORCK
@@ -87,7 +87,7 @@ def __init__(
n_power_iterations=niter_spectral,
)
bjorck_norm(self, name="weight", n_iterations=niter_bjorck)
self.register_forward_pre_hook(self._hook)
self.apply_lipschitz_factor()

def vanilla_export(self) -> torch.nn.Linear:
layer = torch.nn.Linear(
@@ -142,7 +142,7 @@ def __init__(
self.bias.data.fill_(0.0)

frobenius_norm(self, name="weight", disjoint_neurons=disjoint_neurons)
self.register_forward_pre_hook(self._hook)
self.apply_lipschitz_factor()

def vanilla_export(self):
layer = torch.nn.Linear(
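`vanilla_export()` still returns a plain `torch.nn.Linear`, presumably carrying the effective (parametrized) weight. A hedged sketch with illustrative shapes:

from deel.torchlip import SpectralLinear

lip = SpectralLinear(20, 40)
plain = lip.vanilla_export()     # plain torch.nn.Linear built from the current effective weight
print(type(plain).__name__)      # Linear
print(plain.weight.shape)        # torch.Size([40, 20])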
44 changes: 41 additions & 3 deletions deel/torchlip/modules/module.py
@@ -29,18 +29,51 @@
for condensation and vanilla exportation.
"""
import abc
from collections import OrderedDict
import copy
import logging
import math
from collections import OrderedDict
from typing import Any

import numpy as np
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize
from torch.nn import Sequential as TorchSequential

logger = logging.getLogger("deel.torchlip")


def vanilla_model(model: nn.Module):
"""Convert lipschitz modules into their non-lipschitz counterpart (for
instance, SpectralConv2d layers become Conv2d layers).
Warning: This function modifies the model in-place.
Args:
model (nn.Module): Lipschitz neural network
"""
for n, module in model.named_children():
if len(list(module.children())) > 0:
# compound module, go inside it
vanilla_model(module)

if isinstance(module, LipschitzModule):
# simple module
setattr(model, n, module.vanilla_export())


class _LipschitzCoefMultiplication(nn.Module):
"""Parametrization module for lipschitz global coefficient multiplication."""

def __init__(self, coef: float):
super().__init__()
self._coef = coef

def forward(self, weight: torch.Tensor) -> torch.Tensor:
return self._coef * weight


class LipschitzModule(abc.ABC):
"""
This class allows setting the Lipschitz factor of a layer. A Lipschitz layer must inherit
@@ -58,8 +91,13 @@ class LipschitzModule(abc.ABC):
def __init__(self, coefficient_lip: float = 1.0):
self._coefficient_lip = coefficient_lip

def _hook(self, module, inputs):
setattr(module, "weight", getattr(module, "weight") * self._coefficient_lip)
def apply_lipschitz_factor(self):
"""Multiply the layer weights by a lipschitz factor."""
if self._coefficient_lip == 1.0:
return
parametrize.register_parametrization(
self, "weight", _LipschitzCoefMultiplication(self._coefficient_lip)
)

@abc.abstractmethod
def vanilla_export(self):
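`_LipschitzCoefMultiplication` relies only on the generic `torch.nn.utils.parametrize` machinery. A self-contained sketch of the same mechanism on a plain `nn.Linear` (the `_Scale` class below is a stand-in for illustration, not torchlip code):

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class _Scale(nn.Module):
    """Multiply the weight by a constant factor, like _LipschitzCoefMultiplication."""
    def __init__(self, coef: float):
        super().__init__()
        self._coef = coef

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        return self._coef * weight

lin = nn.Linear(4, 4, bias=False)
w0 = lin.weight.detach().clone()
parametrize.register_parametrization(lin, "weight", _Scale(2.0))
print(torch.allclose(lin.weight, 2.0 * w0))        # True: weight is recomputed on every access
print(lin.parametrizations.weight.original.shape)  # the raw parameter is kept under .original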
1 change: 1 addition & 0 deletions deel/torchlip/utils/__init__.py
@@ -39,6 +39,7 @@
from .lconv_norm import remove_lconv_norm
from .sqrt_eps import sqrt_with_gradeps # noqa: F401


DEFAULT_NITER_BJORCK = 15
DEFAULT_NITER_SPECTRAL = 3
DEFAULT_NITER_SPECTRAL_INIT = 10
66 changes: 29 additions & 37 deletions deel/torchlip/utils/bjorck_norm.py
@@ -24,53 +24,43 @@
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
# CRIAQ and ANITI - https://www.deel.ai/
# =====================================================================================
from typing import Any
from typing import TypeVar

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

from ..normalizers import bjorck_normalization
from ..normalizers import DEFAULT_NITER_BJORCK
from .hook_norm import HookNorm


class BjorckNorm(HookNorm):
"""
Bjorck Normalization from https://arxiv.org/abs/1811.05381
"""

n_iterations: int

def __init__(self, module: torch.nn.Module, name: str, n_iterations: int):
super().__init__(module, name)
class _BjorckNorm(nn.Module):
def __init__(self, weight: torch.Tensor, n_iterations: int) -> None:
super().__init__()
self.n_iterations = n_iterations
self.register_buffer("_w_bjorck", weight.data)

def compute_weight(self, module: torch.nn.Module, inputs: Any) -> torch.Tensor:
return bjorck_normalization(self.weight(module), self.n_iterations)

@staticmethod
def apply(module: torch.nn.Module, name: str, n_iterations: int) -> "BjorckNorm":
return BjorckNorm(module, name, n_iterations)


T_module = TypeVar("T_module", bound=torch.nn.Module)
def forward(self, weight: torch.Tensor) -> torch.Tensor:
if self.training:
w_bjorck = bjorck_normalization(weight, self.n_iterations)
self._w_bjorck = w_bjorck.data
else:
w_bjorck = self._w_bjorck
return w_bjorck


def bjorck_norm(
module: T_module, name: str = "weight", n_iterations: int = DEFAULT_NITER_BJORCK
) -> T_module:
module: nn.Module, name: str = "weight", n_iterations: int = DEFAULT_NITER_BJORCK
) -> nn.Module:
r"""
Applies Bjorck normalization to a parameter in the given module.
Bjorck normalization ensures that all eigenvalues of the weight matrix remain close or
equal to one during training. If the dimension of the weight tensor is greater than
2, it is reshaped to 2D for iteration.
This is implemented via a hook that applies Bjorck normalization before every
``forward()`` call.
This is implemented via a Bjorck normalization parametrization.
.. note::
It is recommended to use :py:func:`torch.nn.utils.spectral_norm` before
this hook to greatly reduce the number of iterations required.
It is recommended to use :py:func:`torch.nn.utils.parametrizations.spectral_norm`
before this parametrization to greatly reduce the number of iterations required.
See `Sorting out Lipschitz function approximation
<https://arxiv.org/abs/1811.05381>`_.
@@ -92,11 +82,14 @@ def bjorck_norm(
See Also:
:py:func:`deel.torchlip.normalizers.bjorck_normalization`
"""
BjorckNorm.apply(module, name, n_iterations)
weight = getattr(module, name, None)
parametrize.register_parametrization(
module, name, _BjorckNorm(weight, n_iterations)
)
return module


def remove_bjorck_norm(module: T_module, name: str = "weight") -> T_module:
def remove_bjorck_norm(module: nn.Module, name: str = "weight") -> nn.Module:
r"""
Removes the Bjorck normalization reparameterization from a module.
@@ -108,10 +101,9 @@ def remove_bjorck_norm(module: T_module, name: str = "weight") -> T_module:
>>> m = bjorck_norm(nn.Linear(20, 40))
>>> remove_bjorck_norm(m)
"""
for k, hook in module._forward_pre_hooks.items():
if isinstance(hook, BjorckNorm) and hook.name == name:
hook.remove(module)
del module._forward_pre_hooks[k]
return module

raise ValueError("bjorck_norm of '{}' not found in {}".format(name, module))
for key, m in module.parametrizations[name]._modules.items():
if isinstance(m, _BjorckNorm):
if len(module.parametrizations["weight"]) == 1:
parametrize.remove_parametrizations(module, name)
else:
del module.parametrizations[name]._modules[key]
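A hedged usage sketch of the parametrization-based `bjorck_norm`/`remove_bjorck_norm` pair (the import path is assumed from the package layout; the eval-mode behavior follows the caching in `_BjorckNorm.forward`):

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize
from deel.torchlip.utils import bjorck_norm, remove_bjorck_norm

m = bjorck_norm(nn.Linear(20, 40), n_iterations=15)
print(parametrize.is_parametrized(m, "weight"))   # True
x = torch.randn(2, 20)
_ = m(x)          # training mode: recomputes and caches the Bjorck-normalized weight
m.eval()
_ = m(x)          # eval mode: reuses the cached normalized weight
remove_bjorck_norm(m)
print(parametrize.is_parametrized(m, "weight"))   # False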
48 changes: 17 additions & 31 deletions deel/torchlip/utils/frobenius_norm.py
@@ -24,36 +24,23 @@
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
# CRIAQ and ANITI - https://www.deel.ai/
# =====================================================================================
from typing import Any
from typing import TypeVar

import torch

from .hook_norm import HookNorm
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class FrobeniusNorm(HookNorm):
def __init__(self, module: torch.nn.Module, name: str, disjoint_neurons: bool):
super().__init__(module, name)
class _FrobeniusNorm(nn.Module):
def __init__(self, disjoint_neurons: bool) -> None:
super().__init__()
self.dim_norm = 1 if disjoint_neurons else None

def compute_weight(self, module: torch.nn.Module, inputs: Any) -> torch.Tensor:
w: torch.Tensor = self.weight(module)
return w / torch.norm(w, dim=self.dim_norm, keepdim=True) # type: ignore

@staticmethod
def apply(
module: torch.nn.Module, name: str, disjoint_neurons: bool
) -> "FrobeniusNorm":
return FrobeniusNorm(module, name, disjoint_neurons)


T_module = TypeVar("T_module", bound=torch.nn.Module)
def forward(self, weight: torch.Tensor) -> torch.Tensor:
return weight / torch.norm(weight, dim=self.dim_norm, keepdim=True)


def frobenius_norm(
module: T_module, name: str = "weight", disjoint_neurons: bool = True
) -> T_module:
module: nn.Module, name: str = "weight", disjoint_neurons: bool = True
) -> nn.Module:
r"""
Applies Frobenius normalization to a parameter in the given module.
@@ -78,11 +65,11 @@ def frobenius_norm(
Linear(in_features=20, out_features=40, bias=True)
"""
FrobeniusNorm.apply(module, name, disjoint_neurons)
parametrize.register_parametrization(module, name, _FrobeniusNorm(disjoint_neurons))
return module


def remove_frobenius_norm(module: T_module, name: str = "weight") -> T_module:
def remove_frobenius_norm(module: nn.Module, name: str = "weight") -> nn.Module:
r"""
Removes the Frobenius normalization reparameterization from a module.
@@ -95,10 +82,9 @@ def remove_frobenius_norm(module: T_module, name: str = "weight") -> T_module:
>>> m = frobenius_norm(nn.Linear(20, 40))
>>> remove_frobenius_norm(m)
"""
for k, hook in module._forward_pre_hooks.items():
if isinstance(hook, FrobeniusNorm) and hook.name == name:
hook.remove(module)
del module._forward_pre_hooks[k]
return module

raise ValueError("frobenius_norm of '{}' not found in {}".format(name, module))
for key, m in module.parametrizations[name]._modules.items():
if isinstance(m, _FrobeniusNorm):
if len(module.parametrizations["weight"]) == 1:
parametrize.remove_parametrizations(module, name)
else:
del module.parametrizations[name]._modules[key]
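A quick check of the row-wise behavior when `disjoint_neurons=True` (a hedged sketch; the import path is assumed):

import torch
import torch.nn as nn
from deel.torchlip.utils import frobenius_norm

m = frobenius_norm(nn.Linear(20, 40), disjoint_neurons=True)
row_norms = m.weight.norm(dim=1)                  # effective weight, one norm per output neuron
print(torch.allclose(row_norms, torch.ones(40)))  # True: each row is normalized to unit norm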