diff --git a/deel/lip/utils.py b/deel/lip/utils.py
index dc6281ef..8e9d50eb 100644
--- a/deel/lip/utils.py
+++ b/deel/lip/utils.py
@@ -9,19 +9,15 @@
 import numpy as np
 import tensorflow as tf
 from tensorflow.keras import Model
-from tensorflow.keras import backend as K
 
 
 def evaluate_lip_const_gen(
-    model: Model,
-    generator: Generator[Tuple[np.ndarray, np.ndarray], Any, None],
-    eps=1e-4,
-    seed=None,
+    model: Model, generator: Generator[Tuple[np.ndarray, np.ndarray], Any, None]
 ):
     """
-    Evaluate the Lipschitz constant of a model, with the naive method.
-    Please note that the estimation of the lipschitz constant is done locally around
-    input sample. This may not correctly estimate the behaviour in the whole domain.
+    Evaluate the Lipschitz constant of a model, using the Jacobian of the model.
+    Please note that the estimation of the Lipschitz constant is done locally around
+    input samples. This may not correctly estimate the behaviour in the whole domain.
     The computation might also be inaccurate in high dimensional space.
 
     This is the generator version of evaluate_lip_const.
@@ -29,48 +25,47 @@ def evaluate_lip_const_gen(
     Args:
         model: built keras model used to make predictions
         generator: used to select datapoints where to compute the lipschitz constant
-        eps (float): magnitude of noise to add to input in order to compute the constant
-        seed (int): seed used when generating the noise ( can be set to None )
 
     Returns:
         float: the empirically evaluated lipschitz constant.
 
     """
     x, _ = generator.send(None)
-    return evaluate_lip_const(model, x, eps, seed=seed)
+    return evaluate_lip_const(model, x)
 
 
-def evaluate_lip_const(model: Model, x, eps=1e-4, seed=None):
+def evaluate_lip_const(model: Model, x):
     """
-    Evaluate the Lipschitz constant of a model, with the naive method.
+    Evaluate the Lipschitz constant of a model, using the Jacobian of the model.
     Please note that the estimation of the lipschitz constant is done locally around
-    input sample. This may not correctly estimate the behaviour in the whole domain.
+    input samples. This may not correctly estimate the behaviour in the whole domain.
 
     Args:
         model: built keras model used to make predictions
         x: inputs used to compute the lipschitz constant
-        eps (float): magnitude of noise to add to input in order to compute the constant
-        seed (int): seed used when generating the noise ( can be set to None )
 
     Returns:
-        float: the empirically evaluated lipschitz constant. The computation might also
+        float: the empirically evaluated Lipschitz constant. The computation might also
         be inaccurate in high dimensional space.
 
     """
-    y_pred = model.predict(x)
-    # x = np.repeat(x, 100, 0)
-    # y_pred = np.repeat(y_pred, 100, 0)
-    x_var = x + K.random_uniform(
-        shape=x.shape, minval=eps * 0.25, maxval=eps, seed=seed
-    )
-    y_pred_var = model.predict(x_var)
-    dx = x - x_var
-    dfx = y_pred - y_pred_var
-    ndx = K.sqrt(K.sum(K.square(dx), axis=range(1, len(x.shape))))
-    ndfx = K.sqrt(K.sum(K.square(dfx), axis=range(1, len(y_pred.shape))))
-    lip_cst = K.max(ndfx / ndx)
-    print(f"lip cst: {lip_cst:.3f}")
-    return lip_cst
+    batch_size = x.shape[0]
+    x = tf.constant(x, dtype=model.input.dtype)
+
+    # Get the jacobians of the model w.r.t. the inputs
+    with tf.GradientTape() as tape:
+        tape.watch(x)
+        y_pred = model(x, training=False)
+    batch_jacobian = tape.batch_jacobian(y_pred, x)
+
+    # Reshape the jacobians (in case of multi-dimensional input/output like in conv)
+    xdim = tf.reduce_prod(x.shape[1:])
+    ydim = tf.reduce_prod(y_pred.shape[1:])
+    batch_jacobian = tf.reshape(batch_jacobian, (batch_size, ydim, xdim))
+
+    # Compute the spectral norm of the jacobians and return the maximum
+    b = tf.norm(batch_jacobian, ord=2, axis=[-2, -1]).numpy()
+    return tf.reduce_max(b)
 
 
 def _padding_circular(x, circular_paddings):
diff --git a/tests/test_compute_layer_sv.py b/tests/test_compute_layer_sv.py
index 4efec5b5..0ffa5361 100644
--- a/tests/test_compute_layer_sv.py
+++ b/tests/test_compute_layer_sv.py
@@ -2,8 +2,7 @@
 # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
 # CRIAQ and ANITI - https://www.deel.ai/
 # =====================================================================================
-"""Tests for singular value computation (in compute_layer_sv.py)
-"""
+"""Tests for singular value computation (in compute_layer_sv.py)"""
 import os
 import pprint
 import unittest
diff --git a/tests/test_layers.py b/tests/test_layers.py
index 46c3cfc4..40639217 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -224,7 +224,7 @@ def train_k_lip_model(
         linear_generator(batch_size, input_shape, kernel),
         steps=10,
     )
-    empirical_lip_const = evaluate_lip_const(model=model, x=x, seed=42)
+    empirical_lip_const = evaluate_lip_const(model=model, x=x)
     # save the model
     model_checkpoint_path = os.path.join(logdir, "model.keras")
     model.save(model_checkpoint_path, overwrite=True)
@@ -237,7 +237,7 @@ def train_k_lip_model(
         linear_generator(batch_size, input_shape, kernel),
         steps=10,
     )
-    from_empirical_lip_const = evaluate_lip_const(model=model, x=x, seed=42)
+    from_empirical_lip_const = evaluate_lip_const(model=model, x=x)
     # log metrics
     file_writer = tf.summary.create_file_writer(os.path.join(logdir, "metrics"))
     file_writer.set_as_default()
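
For reference, a minimal sketch of how the reworked evaluate_lip_const could be exercised; the toy model and input batch below are illustrative assumptions, not part of this diff:

import numpy as np
import tensorflow as tf

from deel.lip.utils import evaluate_lip_const

# Hypothetical toy model; any built Keras model exposing model.input works.
model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(8,)),
        tf.keras.layers.Dense(16, activation="relu"),
        tf.keras.layers.Dense(4),
    ]
)

# Batch of points around which the local Lipschitz constant is estimated.
x = np.random.normal(size=(32, 8)).astype("float32")

# Largest spectral norm of the per-sample Jacobians over the batch.
lip_const = evaluate_lip_const(model, x)
print(float(lip_const))

Since the estimate is the maximum Jacobian spectral norm over the sampled batch, it lower-bounds the true Lipschitz constant and depends on where x is drawn, which is why the docstrings stress that the evaluation is local.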