From 55fd48c93141eb447f834f992983a68651253a0f Mon Sep 17 00:00:00 2001 From: Yannick Prudent Date: Mon, 13 Feb 2023 13:57:18 +0100 Subject: [PATCH 1/6] feat: add parametrization --- deel/torchlip/modules/conv.py | 8 +- deel/torchlip/modules/linear.py | 6 +- deel/torchlip/modules/module.py | 23 +++- deel/torchlip/utils/__init__.py | 1 + deel/torchlip/utils/bjorck_norm.py | 66 +++++------ deel/torchlip/utils/frobenius_norm.py | 48 +++----- deel/torchlip/utils/hook_norm.py | 103 ------------------ deel/torchlip/utils/lconv_norm.py | 63 ++++++++--- tests/test_downsampling.py | 1 - ..._hook_norm.py => test_parametrizations.py} | 46 +++++--- tests/test_torch_lip_layers.py | 6 +- tests/test_upsampling.py | 1 - tests/test_vanilla_export.py | 4 +- 13 files changed, 163 insertions(+), 213 deletions(-) delete mode 100644 deel/torchlip/utils/hook_norm.py rename tests/{test_lip_hook_norm.py => test_parametrizations.py} (68%) diff --git a/deel/torchlip/modules/conv.py b/deel/torchlip/modules/conv.py index efe64fb..991f6c3 100644 --- a/deel/torchlip/modules/conv.py +++ b/deel/torchlip/modules/conv.py @@ -27,7 +27,7 @@ import numpy as np import torch from torch.nn.common_types import _size_2_t -from torch.nn.utils import spectral_norm +from torch.nn.utils.parametrizations import spectral_norm from ..utils import bjorck_norm from ..utils import DEFAULT_NITER_BJORCK @@ -111,8 +111,8 @@ def __init__( n_power_iterations=niter_spectral, ) bjorck_norm(self, name="weight", n_iterations=niter_bjorck) - lconv_norm(self) - self.register_forward_pre_hook(self._hook) + lconv_norm(self, name="weight") + self.apply_lipschitz_factor() def vanilla_export(self): layer = torch.nn.Conv2d( @@ -172,7 +172,7 @@ def __init__( frobenius_norm(self, name="weight", disjoint_neurons=False) lconv_norm(self) - self.register_forward_pre_hook(self._hook) + self.apply_lipschitz_factor() def vanilla_export(self): layer = torch.nn.Conv2d( diff --git a/deel/torchlip/modules/linear.py b/deel/torchlip/modules/linear.py index 8b249d7..de044cf 100644 --- a/deel/torchlip/modules/linear.py +++ b/deel/torchlip/modules/linear.py @@ -25,7 +25,7 @@ # CRIAQ and ANITI - https://www.deel.ai/ # ===================================================================================== import torch -from torch.nn.utils import spectral_norm +from torch.nn.utils.parametrizations import spectral_norm from ..utils import bjorck_norm from ..utils import DEFAULT_NITER_BJORCK @@ -87,7 +87,7 @@ def __init__( n_power_iterations=niter_spectral, ) bjorck_norm(self, name="weight", n_iterations=niter_bjorck) - self.register_forward_pre_hook(self._hook) + self.apply_lipschitz_factor() def vanilla_export(self) -> torch.nn.Linear: layer = torch.nn.Linear( @@ -142,7 +142,7 @@ def __init__( self.bias.data.fill_(0.0) frobenius_norm(self, name="weight", disjoint_neurons=disjoint_neurons) - self.register_forward_pre_hook(self._hook) + self.apply_lipschitz_factor() def vanilla_export(self): layer = torch.nn.Linear( diff --git a/deel/torchlip/modules/module.py b/deel/torchlip/modules/module.py index 61d59c7..eb93854 100644 --- a/deel/torchlip/modules/module.py +++ b/deel/torchlip/modules/module.py @@ -29,18 +29,32 @@ for condensation and vanilla exportation. 
""" import abc -from collections import OrderedDict import copy import logging import math +from collections import OrderedDict from typing import Any import numpy as np +import torch +import torch.nn as nn +import torch.nn.utils.parametrize as parametrize from torch.nn import Sequential as TorchSequential logger = logging.getLogger("deel.torchlip") +class _LipschitzCoefMultiplication(nn.Module): + """Parametrization module for lipschitz global coefficient multiplication.""" + + def __init__(self, coef: float): + super().__init__() + self._coef = coef + + def forward(self, weight: torch.Tensor) -> torch.Tensor: + return self._coef * weight + + class LipschitzModule(abc.ABC): """ This class allow to set lipschitz factor of a layer. Lipschitz layer must inherit @@ -58,8 +72,11 @@ class LipschitzModule(abc.ABC): def __init__(self, coefficient_lip: float = 1.0): self._coefficient_lip = coefficient_lip - def _hook(self, module, inputs): - setattr(module, "weight", getattr(module, "weight") * self._coefficient_lip) + def apply_lipschitz_factor(self): + """Multiply the layer weights by a lipschitz factor.""" + parametrize.register_parametrization( + self, "weight", _LipschitzCoefMultiplication(self._coefficient_lip) + ) @abc.abstractmethod def vanilla_export(self): diff --git a/deel/torchlip/utils/__init__.py b/deel/torchlip/utils/__init__.py index 1f9dd6c..7e559c5 100644 --- a/deel/torchlip/utils/__init__.py +++ b/deel/torchlip/utils/__init__.py @@ -39,6 +39,7 @@ from .lconv_norm import remove_lconv_norm from .sqrt_eps import sqrt_with_gradeps # noqa: F401 + DEFAULT_NITER_BJORCK = 15 DEFAULT_NITER_SPECTRAL = 3 DEFAULT_NITER_SPECTRAL_INIT = 10 diff --git a/deel/torchlip/utils/bjorck_norm.py b/deel/torchlip/utils/bjorck_norm.py index 72c6884..0d5dd75 100644 --- a/deel/torchlip/utils/bjorck_norm.py +++ b/deel/torchlip/utils/bjorck_norm.py @@ -24,53 +24,43 @@ # rights reserved. 
DEEL is a research program operated by IVADO, IRT Saint Exupéry, # CRIAQ and ANITI - https://www.deel.ai/ # ===================================================================================== -from typing import Any -from typing import TypeVar - import torch +import torch.nn as nn +import torch.nn.utils.parametrize as parametrize from ..normalizers import bjorck_normalization from ..normalizers import DEFAULT_NITER_BJORCK -from .hook_norm import HookNorm - - -class BjorckNorm(HookNorm): - """ - Bjorck Normalization from https://arxiv.org/abs/1811.05381 - """ - n_iterations: int - def __init__(self, module: torch.nn.Module, name: str, n_iterations: int): - super().__init__(module, name) +class _BjorckNorm(nn.Module): + def __init__(self, weight: torch.Tensor, n_iterations: int) -> None: + super().__init__() self.n_iterations = n_iterations + self.register_buffer("_w_bjorck", weight.data) - def compute_weight(self, module: torch.nn.Module, inputs: Any) -> torch.Tensor: - return bjorck_normalization(self.weight(module), self.n_iterations) - - @staticmethod - def apply(module: torch.nn.Module, name: str, n_iterations: int) -> "BjorckNorm": - return BjorckNorm(module, name, n_iterations) - - -T_module = TypeVar("T_module", bound=torch.nn.Module) + def forward(self, weight: torch.Tensor) -> torch.Tensor: + if self.training: + w_bjorck = bjorck_normalization(weight, self.n_iterations) + self._w_bjorck = w_bjorck.data + else: + w_bjorck = self._w_bjorck + return w_bjorck def bjorck_norm( - module: T_module, name: str = "weight", n_iterations: int = DEFAULT_NITER_BJORCK -) -> T_module: + module: nn.Module, name: str = "weight", n_iterations: int = DEFAULT_NITER_BJORCK +) -> nn.Module: r""" Applies Bjorck normalization to a parameter in the given module. Bjorck normalization ensures that all eigen values of a vectors remain close or equal to one during training. If the dimension of the weight tensor is greater than 2, it is reshaped to 2D for iteration. - This is implemented via a hook that applies Bjorck normalization before every - ``forward()`` call. + This is implemented via a Bjorck normalization parametrization. .. note:: - It is recommended to use :py:func:`torch.nn.utils.spectral_norm` before - this hook to greatly reduce the number of iterations required. + It is recommended to use :py:func:`torch.nn.utils.parameterize.spectral_norm` + before this hook to greatly reduce the number of iterations required. See `Sorting out Lipschitz function approximation `_. @@ -92,11 +82,14 @@ def bjorck_norm( See Also: :py:func:`deel.torchlip.normalizers.bjorck_normalization` """ - BjorckNorm.apply(module, name, n_iterations) + weight = getattr(module, name, None) + parametrize.register_parametrization( + module, name, _BjorckNorm(weight, n_iterations) + ) return module -def remove_bjorck_norm(module: T_module, name: str = "weight") -> T_module: +def remove_bjorck_norm(module: nn.Module, name: str = "weight") -> nn.Module: r""" Removes the Bjorck normalization reparameterization from a module. 
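The _BjorckNorm module introduced above replaces the former forward-pre-hook: it is registered through torch.nn.utils.parametrize.register_parametrization, recomputes the Bjorck-normalized weight while the layer is in training mode, and serves the cached buffer in eval mode. A minimal usage sketch, assuming only the helpers visible in this patch (tensor sizes are illustrative):

    import torch
    from deel.torchlip.utils.bjorck_norm import bjorck_norm, remove_bjorck_norm

    m = bjorck_norm(torch.nn.Linear(20, 40), n_iterations=15)
    assert isinstance(m.parametrizations.weight.original, torch.nn.Parameter)

    m.train()
    _ = m(torch.rand(20))    # training mode: Bjorck iterations run and the result is cached
    m.eval()
    _ = m(torch.rand(20))    # eval mode: the cached buffer is reused, no extra iterations

    remove_bjorck_norm(m)    # folds the normalized value back into a plain m.weight Parameter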
@@ -108,10 +101,9 @@ def remove_bjorck_norm(module: T_module, name: str = "weight") -> T_module: >>> m = bjorck_norm(nn.Linear(20, 40)) >>> remove_bjorck_norm(m) """ - for k, hook in module._forward_pre_hooks.items(): - if isinstance(hook, BjorckNorm) and hook.name == name: - hook.remove(module) - del module._forward_pre_hooks[k] - return module - - raise ValueError("bjorck_norm of '{}' not found in {}".format(name, module)) + for key, m in module.parametrizations[name]._modules.items(): + if isinstance(m, _BjorckNorm): + if len(module.parametrizations["weight"]) == 1: + parametrize.remove_parametrizations(module, name) + else: + del module.parametrizations[name]._modules[key] diff --git a/deel/torchlip/utils/frobenius_norm.py b/deel/torchlip/utils/frobenius_norm.py index 6aa8a95..825a73d 100644 --- a/deel/torchlip/utils/frobenius_norm.py +++ b/deel/torchlip/utils/frobenius_norm.py @@ -24,36 +24,23 @@ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry, # CRIAQ and ANITI - https://www.deel.ai/ # ===================================================================================== -from typing import Any -from typing import TypeVar - import torch - -from .hook_norm import HookNorm +import torch.nn as nn +import torch.nn.utils.parametrize as parametrize -class FrobeniusNorm(HookNorm): - def __init__(self, module: torch.nn.Module, name: str, disjoint_neurons: bool): - super().__init__(module, name) +class _FrobeniusNorm(nn.Module): + def __init__(self, disjoint_neurons: bool) -> None: + super().__init__() self.dim_norm = 1 if disjoint_neurons else None - def compute_weight(self, module: torch.nn.Module, inputs: Any) -> torch.Tensor: - w: torch.Tensor = self.weight(module) - return w / torch.norm(w, dim=self.dim_norm, keepdim=True) # type: ignore - - @staticmethod - def apply( - module: torch.nn.Module, name: str, disjoint_neurons: bool - ) -> "FrobeniusNorm": - return FrobeniusNorm(module, name, disjoint_neurons) - - -T_module = TypeVar("T_module", bound=torch.nn.Module) + def forward(self, weight: torch.Tensor) -> torch.Tensor: + return weight / torch.norm(weight, dim=self.dim_norm, keepdim=True) def frobenius_norm( - module: T_module, name: str = "weight", disjoint_neurons: bool = True -) -> T_module: + module: nn.Module, name: str = "weight", disjoint_neurons: bool = True +) -> nn.Module: r""" Applies Frobenius normalization to a parameter in the given module. @@ -78,11 +65,11 @@ def frobenius_norm( Linear(in_features=20, out_features=40, bias=True) """ - FrobeniusNorm.apply(module, name, disjoint_neurons) + parametrize.register_parametrization(module, name, _FrobeniusNorm(disjoint_neurons)) return module -def remove_frobenius_norm(module: T_module, name: str = "weight") -> T_module: +def remove_frobenius_norm(module: nn.Module, name: str = "weight") -> nn.Module: r""" Removes the Frobenius normalization reparameterization from a module. 
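The Frobenius reparametrization follows the same pattern: _FrobeniusNorm divides the weight by its norm, row by row when disjoint_neurons=True. A short sketch of the expected behaviour, based only on the helpers and test expectations in this series:

    import torch
    from deel.torchlip.utils.frobenius_norm import frobenius_norm, remove_frobenius_norm

    m = frobenius_norm(torch.nn.Linear(5, 3), disjoint_neurons=True)
    _ = m(torch.rand(5))
    print(torch.norm(m.weight, dim=1))   # each output neuron (row) has unit norm

    remove_frobenius_norm(m)             # m.weight becomes a plain Parameter holding the normalized value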
@@ -95,10 +82,9 @@ def remove_frobenius_norm(module: T_module, name: str = "weight") -> T_module: >>> m = frobenius_norm(nn.Linear(20, 40)) >>> remove_frobenius_norm(m) """ - for k, hook in module._forward_pre_hooks.items(): - if isinstance(hook, FrobeniusNorm) and hook.name == name: - hook.remove(module) - del module._forward_pre_hooks[k] - return module - - raise ValueError("frobenius_norm of '{}' not found in {}".format(name, module)) + for key, m in module.parametrizations[name]._modules.items(): + if isinstance(m, _FrobeniusNorm): + if len(module.parametrizations["weight"]) == 1: + parametrize.remove_parametrizations(module, name) + else: + del module.parametrizations[name]._modules[key] diff --git a/deel/torchlip/utils/hook_norm.py b/deel/torchlip/utils/hook_norm.py deleted file mode 100644 index fd508e3..0000000 --- a/deel/torchlip/utils/hook_norm.py +++ /dev/null @@ -1,103 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All -# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry, -# CRIAQ and ANITI - https://www.deel.ai/ -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All -# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry, -# CRIAQ and ANITI - https://www.deel.ai/ -# ===================================================================================== -from abc import abstractmethod -from typing import Any - -import inflection -import torch - - -class HookNorm: - - """ - Base class for pre-forward hook that modifies parameters of a module. The - constructor register the hook on the module, and sub-classes should only - implement the compute_weight method. - """ - - _name: str - _first: bool - - def __init__(self, module: torch.nn.Module, name: str = "weight"): - self._name = name - self._first = False - - if isinstance(getattr(module, name), torch.nn.Parameter): - weight = module._parameters[name] - self._first = True - delattr(module, name) - module.register_parameter(name + "_orig", weight) - setattr(module, name, weight.data) - - for k, hook in module._forward_pre_hooks.items(): - if isinstance(hook, type(self)) and hook.name == name: - raise RuntimeError( - "Cannot register two {} hooks on " - "the same parameter {}.".format( - inflection.underscore(type(self).__name__), name - ) - ) - - # Normalize weight before every forward(). 
- module.register_forward_pre_hook(self) - - def weight(self, module: torch.nn.Module) -> torch.Tensor: - """ - Returns: - The weight to apply the transformation to. This is not always the value - of the attribute corresponding to `name`. - """ - if self._first: - weight = getattr(module, self._name + "_orig") - else: - weight = getattr(module, self._name) - return weight # type: ignore - - @property - def name(self) -> str: - """ - Returns: - The name of the attribute that should be set by this hook. - """ - return self._name - - def remove(self, module: torch.nn.Module): - # If this was the first layer to hook, we reset the weights. - if self._first: - weight = getattr(module, self._name) - delattr(module, self._name) - module.register_parameter(self._name, torch.nn.Parameter(weight.detach())) - - @abstractmethod - def compute_weight(self, module: torch.nn.Module, inputs: Any) -> torch.Tensor: - """ - Transform the weight of the given module. - """ - pass - - def __call__(self, module: torch.nn.Conv2d, inputs: Any): - setattr(module, self.name, self.compute_weight(module, inputs)) diff --git a/deel/torchlip/utils/lconv_norm.py b/deel/torchlip/utils/lconv_norm.py index fd08b36..bb63031 100644 --- a/deel/torchlip/utils/lconv_norm.py +++ b/deel/torchlip/utils/lconv_norm.py @@ -29,8 +29,8 @@ import numpy as np import torch - -from .hook_norm import HookNorm +import torch.nn as nn +import torch.nn.utils.parametrize as parametrize def compute_lconv_coef( @@ -58,32 +58,68 @@ def compute_lconv_coef( return coefLip # type: ignore -class LConvNorm(HookNorm): +class _LConvNorm(nn.Module): + """Parametrization module for Lipschitz normalization.""" + + def __init__(self, lconv_coefficient: float) -> None: + super().__init__() + self.lconv_coefficient = lconv_coefficient + + def forward(self, weight: torch.Tensor) -> torch.Tensor: + return weight * self.lconv_coefficient + + +class LConvNormHook: """ Kernel normalization for Lipschitz convolution. Normalize weights based on input shape and kernel size, see https://arxiv.org/abs/2006.06520 """ - @staticmethod - def apply(module: torch.nn.Module) -> "LConvNorm": + def apply(self, module: torch.nn.Module, name: str = "weight") -> None: + self.name = name + self.coefficient = None if not isinstance(module, torch.nn.Conv2d): raise RuntimeError( "Can only apply lconv_norm hooks on 2D-convolutional layer." ) - return LConvNorm(module, "weight") + module.register_forward_pre_hook(self) - def compute_weight(self, module: torch.nn.Module, inputs: Any) -> torch.Tensor: - assert isinstance(module, torch.nn.Conv2d) + def __call__(self, module: torch.nn.Conv2d, inputs: Any): coefficient = compute_lconv_coef( module.kernel_size, inputs[0].shape[-4:], module.stride ) - return self.weight(module) * coefficient + # the parametrization is updated only if the coefficient has changed + if coefficient != self.coefficient: + if hasattr(module, "parametrizations"): + self.remove_parametrization(module) + parametrize.register_parametrization( + module, self.name, _LConvNorm(coefficient) + ) + self.coefficient = coefficient + + def remove_parametrization(self, module: nn.Module) -> nn.Module: + r""" + Removes the normalization reparameterization from a module. + + Args: + module: Containing module. 
+ + Example: + >>> m = bjorck_norm(nn.Linear(20, 40)) + >>> remove_bjorck_norm(m) + """ + for key, m in module.parametrizations[self.name]._modules.items(): + if isinstance(m, _LConvNorm): + if len(module.parametrizations[self.name]) == 1: + parametrize.remove_parametrizations(module, self.name) + else: + del module.parametrizations[self.name]._modules[key] -def lconv_norm(module: torch.nn.Conv2d) -> torch.nn.Conv2d: +def lconv_norm(module: torch.nn.Conv2d, name: str = "weight") -> torch.nn.Conv2d: r""" Applies Lipschitz normalization to a kernel in the given convolutional. This is implemented via a hook that multiplies the kernel by a value computed @@ -94,6 +130,7 @@ def lconv_norm(module: torch.nn.Conv2d) -> torch.nn.Conv2d: Args: module: Containing module. + name: Name of weight parameter. Returns: The original module with the Lipschitz normalization hook. @@ -105,7 +142,7 @@ def lconv_norm(module: torch.nn.Conv2d) -> torch.nn.Conv2d: Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1)) """ - LConvNorm.apply(module) + LConvNormHook().apply(module, name) return module @@ -122,8 +159,8 @@ def remove_lconv_norm(module: torch.nn.Conv2d) -> torch.nn.Conv2d: >>> remove_lconv_norm(m) """ for k, hook in module._forward_pre_hooks.items(): - if isinstance(hook, LConvNorm): - hook.remove(module) + if isinstance(hook, LConvNormHook): + hook.remove_parametrization(module) del module._forward_pre_hooks[k] return module diff --git a/tests/test_downsampling.py b/tests/test_downsampling.py index 022c291..6551750 100644 --- a/tests/test_downsampling.py +++ b/tests/test_downsampling.py @@ -31,7 +31,6 @@ def test_invertible_downsample(): - # 1D input x = torch.tensor([[[1, 2, 3, 4], [5, 6, 7, 8]]]) x = invertible_downsample(x, (2,)) diff --git a/tests/test_lip_hook_norm.py b/tests/test_parametrizations.py similarity index 68% rename from tests/test_lip_hook_norm.py rename to tests/test_parametrizations.py index efaebf6..2da0c3f 100644 --- a/tests/test_lip_hook_norm.py +++ b/tests/test_parametrizations.py @@ -28,50 +28,63 @@ import torch_testing as tt from deel.torchlip.normalizers import bjorck_normalization -from deel.torchlip.utils import bjorck_norm -from deel.torchlip.utils import frobenius_norm -from deel.torchlip.utils import lconv_norm -from deel.torchlip.utils import remove_bjorck_norm -from deel.torchlip.utils import remove_frobenius_norm -from deel.torchlip.utils import remove_lconv_norm +from deel.torchlip.utils.bjorck_norm import bjorck_norm +from deel.torchlip.utils.bjorck_norm import remove_bjorck_norm +from deel.torchlip.utils.frobenius_norm import frobenius_norm +from deel.torchlip.utils.frobenius_norm import remove_frobenius_norm from deel.torchlip.utils.lconv_norm import compute_lconv_coef +from deel.torchlip.utils.lconv_norm import lconv_norm +from deel.torchlip.utils.lconv_norm import remove_lconv_norm def test_bjorck_norm(): """ - test bjorck_norm hook implementation + test bjorck_norm parametrization implementation """ m = torch.nn.Linear(2, 2) torch.nn.init.orthogonal_(m.weight) w1 = bjorck_normalization(m.weight) + # bjorck norm parametrization bjorck_norm(m) + + # ensure that the original weight is the only torch parameter + assert isinstance(m.parametrizations.weight.original, torch.nn.Parameter) assert not isinstance(m.weight, torch.nn.Parameter) + # check the orthogonality of the weight x = torch.rand(2) m(x) tt.assert_equal(w1, m.weight) + # remove the parametrization remove_bjorck_norm(m) + assert not hasattr(m, "parametrizations") assert isinstance(m.weight, torch.nn.Parameter) 
tt.assert_equal(w1, m.weight) def test_frobenius_norm(): """ - test frobenius_norm hook implementation + test frobenius_norm parametrization implementation """ m = torch.nn.Linear(2, 2) torch.nn.init.uniform_(m.weight) w1 = m.weight / torch.norm(m.weight) + # frobenius norm parametrization frobenius_norm(m, disjoint_neurons=False) + + # ensure that the original weight is the only torch parameter + assert isinstance(m.parametrizations.weight.original, torch.nn.Parameter) assert not isinstance(m.weight, torch.nn.Parameter) + # check the orthogonality of the weight x = torch.rand(2) m(x) tt.assert_equal(w1, m.weight) + # remove the parametrization remove_frobenius_norm(m) assert isinstance(m.weight, torch.nn.Parameter) tt.assert_equal(w1, m.weight) @@ -79,11 +92,11 @@ def test_frobenius_norm(): def test_frobenius_norm_disjoint_neurons(): """ - Test `disjoint_neurons=True` parameter in frobenius_norm hook + Test `disjoint_neurons=True` argument in frobenius_norm parametrization """ m = torch.nn.Linear(in_features=5, out_features=3) - # Set hook and perform a forward pass to compute new weights + # Set parametrization and perform a forward pass to compute new weights frobenius_norm(m, disjoint_neurons=True) m(torch.rand(5)) @@ -94,19 +107,26 @@ def test_frobenius_norm_disjoint_neurons(): def test_lconv_norm(): """ - test lconv_norm hook implementation + test lconv_norm parametrization implementation """ m = torch.nn.Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1)) torch.nn.init.orthogonal_(m.weight) w1 = m.weight * compute_lconv_coef(m.kernel_size, (1, 1, 5, 5), m.stride) + # lconv norm parametrization lconv_norm(m) + x = torch.rand(1, 1, 5, 5) + y = m(x) + + # ensure that the original weight is the only torch parameter + assert isinstance(m.parametrizations.weight.original, torch.nn.Parameter) assert not isinstance(m.weight, torch.nn.Parameter) - x = torch.rand(1, 1, 5, 5) - m(x) + # check the normalization of the weight tt.assert_equal(w1, m.weight) + tt.assert_equal(y, torch.nn.functional.conv2d(x, w1, bias=m.bias, stride=(1, 1))) + # remove the parametrization remove_lconv_norm(m) assert isinstance(m.weight, torch.nn.Parameter) tt.assert_equal(w1, m.weight) diff --git a/tests/test_torch_lip_layers.py b/tests/test_torch_lip_layers.py index fbd1a87..7067de9 100644 --- a/tests/test_torch_lip_layers.py +++ b/tests/test_torch_lip_layers.py @@ -290,10 +290,12 @@ def train_k_lip_model( empirical_lip_const = evaluate_lip_const(model=model, x=x, seed=42) # save the model model_checkpoint_path = os.path.join(logdir, "model.h5") - save(model, model_checkpoint_path) + save(model.state_dict(), model_checkpoint_path) del model - model = load(model_checkpoint_path) + # load model + model = generate_k_lip_model(layer_type, layer_params, k_lip_model) + model.load_state_dict(load(model_checkpoint_path)) model.eval() np.random.seed(42) manual_seed(42) diff --git a/tests/test_upsampling.py b/tests/test_upsampling.py index cc22053..c8eebb1 100644 --- a/tests/test_upsampling.py +++ b/tests/test_upsampling.py @@ -31,7 +31,6 @@ def test_invertible_upsample(): - # 1D input x = torch.tensor([[[1, 2], [3, 4], [5, 6], [7, 8]]]) x = invertible_upsample(x, (2,)) diff --git a/tests/test_vanilla_export.py b/tests/test_vanilla_export.py index dacdc1e..2e40a2a 100644 --- a/tests/test_vanilla_export.py +++ b/tests/test_vanilla_export.py @@ -24,10 +24,10 @@ # rights reserved. 
DEEL is a research program operated by IVADO, IRT Saint Exupéry, # CRIAQ and ANITI - https://www.deel.ai/ # ===================================================================================== -from deel import torchlip - from collections import OrderedDict +from deel import torchlip + def get_named_model(): return torchlip.Sequential( From 80a75a0fc10956f3b4e180a91f1f7f91621c7850 Mon Sep 17 00:00:00 2001 From: Yannick Prudent Date: Tue, 14 Feb 2023 10:53:18 +0100 Subject: [PATCH 2/6] chore: update py version in workflows (3.6 deprec) --- .github/workflows/python-lints.yml | 2 +- .github/workflows/python-tests.yml | 2 +- setup.cfg | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/python-lints.yml b/.github/workflows/python-lints.yml index 1c40c0d..86cfdf8 100644 --- a/.github/workflows/python-lints.yml +++ b/.github/workflows/python-lints.yml @@ -8,7 +8,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: [3.6, 3.7, 3.8] + python-version: [3.7, 3.8, 3.9] steps: - uses: actions/checkout@v1 diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index ef64807..65d680f 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -8,7 +8,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: [3.6, 3.7, 3.8] + python-version: [3.7, 3.8, 3.9] steps: - uses: actions/checkout@v1 diff --git a/setup.cfg b/setup.cfg index 6ec0a38..7e46ebb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,7 +18,7 @@ ignore_missing_imports = True ignore_missing_imports = True [tox:tox] -envlist = py{36,37,38},py{36,37,38}-lint +envlist = py{37,38,39},py{37,38,39}-lint [testenv] pip_version = pip>=20 @@ -31,7 +31,7 @@ install_command = pip install --find-links https://download.pytorch.org/whl/torc commands = pytest tests -[testenv:py{36,37,38}-lint] +[testenv:py{37,38,39}-lint] skip_install = true deps = black From 5ab989ba08e0a8695407067558a71b9eb21cce13 Mon Sep 17 00:00:00 2001 From: Yannick Prudent Date: Thu, 30 Nov 2023 16:09:39 +0100 Subject: [PATCH 3/6] feat: add in-place vanilla model conversion --- deel/torchlip/modules/module.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/deel/torchlip/modules/module.py b/deel/torchlip/modules/module.py index eb93854..eaf0a73 100644 --- a/deel/torchlip/modules/module.py +++ b/deel/torchlip/modules/module.py @@ -44,6 +44,25 @@ logger = logging.getLogger("deel.torchlip") +def vanilla_model(model: nn.Module): + """Convert lipschitz modules into their non-lipschitz counterpart (for + instance, SpectralConv2d layers become Conv2d layers). + + Warning: This function modifies the model in-place. 
+ + Args: + model (nn.Module): Lipschitz neural network + """ + for n, module in model.named_children(): + if len(list(module.children())) > 0: + # compound module, go inside it + vanilla_model(module) + + if isinstance(module, LipschitzModule): + # simple module + setattr(model, n, module.vanilla_export()) + + class _LipschitzCoefMultiplication(nn.Module): """Parametrization module for lipschitz global coefficient multiplication.""" From effd7e6ec6206f9d326ac7d02296a08d845062e0 Mon Sep 17 00:00:00 2001 From: Yannick Prudent Date: Mon, 15 Jan 2024 15:14:49 +0100 Subject: [PATCH 4/6] fix: asymptotic coef value for lconv It does not depends on input shape anymore --- deel/torchlip/utils/lconv_norm.py | 80 ++++++------------------------- tests/test_parametrizations.py | 2 +- 2 files changed, 16 insertions(+), 66 deletions(-) diff --git a/deel/torchlip/utils/lconv_norm.py b/deel/torchlip/utils/lconv_norm.py index bb63031..7ba0ef3 100644 --- a/deel/torchlip/utils/lconv_norm.py +++ b/deel/torchlip/utils/lconv_norm.py @@ -24,7 +24,6 @@ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry, # CRIAQ and ANITI - https://www.deel.ai/ # ===================================================================================== -from typing import Any from typing import Tuple import numpy as np @@ -35,15 +34,15 @@ def compute_lconv_coef( kernel_size: Tuple[int, ...], - input_shape: Tuple[int, ...], + input_shape: Tuple[int, ...] = None, strides: Tuple[int, ...] = (1, 1), ) -> float: # See https://arxiv.org/abs/2006.06520 stride = np.prod(strides) k1, k2 = kernel_size - h, w = input_shape[-2:] - if stride == 1: + if stride == 1 and input_shape is not None: + h, w = input_shape[-2:] k1_div2 = (k1 - 1) / 2 k2_div2 = (k2 - 1) / 2 coefLip = np.sqrt( @@ -59,7 +58,7 @@ def compute_lconv_coef( class _LConvNorm(nn.Module): - """Parametrization module for Lipschitz normalization.""" + """Parametrization module for kernel normalization of lipschitz convolution.""" def __init__(self, lconv_coefficient: float) -> None: super().__init__() @@ -69,56 +68,6 @@ def forward(self, weight: torch.Tensor) -> torch.Tensor: return weight * self.lconv_coefficient -class LConvNormHook: - - """ - Kernel normalization for Lipschitz convolution. Normalize weights - based on input shape and kernel size, see https://arxiv.org/abs/2006.06520 - """ - - def apply(self, module: torch.nn.Module, name: str = "weight") -> None: - self.name = name - self.coefficient = None - - if not isinstance(module, torch.nn.Conv2d): - raise RuntimeError( - "Can only apply lconv_norm hooks on 2D-convolutional layer." - ) - - module.register_forward_pre_hook(self) - - def __call__(self, module: torch.nn.Conv2d, inputs: Any): - coefficient = compute_lconv_coef( - module.kernel_size, inputs[0].shape[-4:], module.stride - ) - # the parametrization is updated only if the coefficient has changed - if coefficient != self.coefficient: - if hasattr(module, "parametrizations"): - self.remove_parametrization(module) - parametrize.register_parametrization( - module, self.name, _LConvNorm(coefficient) - ) - self.coefficient = coefficient - - def remove_parametrization(self, module: nn.Module) -> nn.Module: - r""" - Removes the normalization reparameterization from a module. - - Args: - module: Containing module. 
- - Example: - >>> m = bjorck_norm(nn.Linear(20, 40)) - >>> remove_bjorck_norm(m) - """ - for key, m in module.parametrizations[self.name]._modules.items(): - if isinstance(m, _LConvNorm): - if len(module.parametrizations[self.name]) == 1: - parametrize.remove_parametrizations(module, self.name) - else: - del module.parametrizations[self.name]._modules[key] - - def lconv_norm(module: torch.nn.Conv2d, name: str = "weight") -> torch.nn.Conv2d: r""" Applies Lipschitz normalization to a kernel in the given convolutional. @@ -142,26 +91,27 @@ def lconv_norm(module: torch.nn.Conv2d, name: str = "weight") -> torch.nn.Conv2d Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1)) """ - LConvNormHook().apply(module, name) + coefficient = compute_lconv_coef(module.kernel_size, None, module.stride) + parametrize.register_parametrization(module, name, _LConvNorm(coefficient)) return module -def remove_lconv_norm(module: torch.nn.Conv2d) -> torch.nn.Conv2d: +def remove_lconv_norm(module: torch.nn.Conv2d, name: str = "weight") -> torch.nn.Conv2d: r""" - Removes the Lipschitz normalization hook from a module. + Removes the normalization parametrization for lipschitz convolution from a module. Args: module: Containing module. + name: Name of weight parameter. Example: >>> m = lconv_norm(nn.Conv2d(16, 16, (3, 3))) >>> remove_lconv_norm(m) """ - for k, hook in module._forward_pre_hooks.items(): - if isinstance(hook, LConvNormHook): - hook.remove_parametrization(module) - del module._forward_pre_hooks[k] - return module - - raise ValueError("lconv_norm not found in {}".format(module)) + for key, m in module.parametrizations[name]._modules.items(): + if isinstance(m, _LConvNorm): + if len(module.parametrizations[name]) == 1: + parametrize.remove_parametrizations(module, name) + else: + del module.parametrizations[name]._modules[key] diff --git a/tests/test_parametrizations.py b/tests/test_parametrizations.py index 2da0c3f..29b995c 100644 --- a/tests/test_parametrizations.py +++ b/tests/test_parametrizations.py @@ -111,7 +111,7 @@ def test_lconv_norm(): """ m = torch.nn.Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1)) torch.nn.init.orthogonal_(m.weight) - w1 = m.weight * compute_lconv_coef(m.kernel_size, (1, 1, 5, 5), m.stride) + w1 = m.weight * compute_lconv_coef(m.kernel_size, None, m.stride) # lconv norm parametrization lconv_norm(m) From 87f78e9392b00b10d595ff385eeb65c91885b155 Mon Sep 17 00:00:00 2001 From: Yannick Prudent Date: Mon, 5 Feb 2024 17:22:28 +0100 Subject: [PATCH 5/6] chore: vanilla_model added in deel/torchlip/__init__.py --- deel/torchlip/__init__.py | 1 + deel/torchlip/modules/__init__.py | 1 + 2 files changed, 2 insertions(+) diff --git a/deel/torchlip/__init__.py b/deel/torchlip/__init__.py index be2c5b0..72c0e52 100644 --- a/deel/torchlip/__init__.py +++ b/deel/torchlip/__init__.py @@ -55,4 +55,5 @@ "Sequential", "SpectralConv2d", "SpectralLinear", + "vanilla_model", ] diff --git a/deel/torchlip/modules/__init__.py b/deel/torchlip/modules/__init__.py index b2653d3..7f730d4 100644 --- a/deel/torchlip/modules/__init__.py +++ b/deel/torchlip/modules/__init__.py @@ -64,6 +64,7 @@ from .loss import NegKRLoss from .module import LipschitzModule from .module import Sequential +from .module import vanilla_model from .pooling import ScaledAdaptiveAvgPool2d from .pooling import ScaledAvgPool2d from .pooling import ScaledL2NormPool2d From b8a0b3d5ea0fc0d173140301fbb240253b705527 Mon Sep 17 00:00:00 2001 From: Franck Mamalet Date: Mon, 15 Apr 2024 07:59:13 +0200 Subject: [PATCH 6/6] remove parametrization 
for lipschitz param when coef_lip==1.0 --- deel/torchlip/modules/module.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/deel/torchlip/modules/module.py b/deel/torchlip/modules/module.py index eaf0a73..ea35fad 100644 --- a/deel/torchlip/modules/module.py +++ b/deel/torchlip/modules/module.py @@ -93,6 +93,8 @@ def __init__(self, coefficient_lip: float = 1.0): def apply_lipschitz_factor(self): """Multiply the layer weights by a lipschitz factor.""" + if self._coefficient_lip == 1.0: + return parametrize.register_parametrization( self, "weight", _LipschitzCoefMultiplication(self._coefficient_lip) )
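Taken together, the six patches move every weight transformation onto torch parametrizations and add an in-place vanilla conversion. A minimal end-to-end sketch, assuming the layer constructors follow the usual torch.nn argument conventions (their exact signatures are not shown in this series):

    import torch
    from deel import torchlip

    model = torchlip.Sequential(torchlip.SpectralLinear(20, 10))
    _ = model(torch.rand(8, 20))     # weights are produced by the registered parametrizations

    # with the default Lipschitz coefficient of 1.0, no extra scaling parametrization is added
    torchlip.vanilla_model(model)    # in-place: each Lipschitz layer is replaced by its vanilla_export()
    print(model)                     # the SpectralLinear is now a plain torch.nn.Linear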