
feat: change Lipschitz cst estimation to Jacobian method #95

Merged · 2 commits · Feb 13, 2025
57 changes: 26 additions & 31 deletions deel/lip/utils.py
@@ -9,68 +9,63 @@
 import numpy as np
 import tensorflow as tf
 from tensorflow.keras import Model
-from tensorflow.keras import backend as K


 def evaluate_lip_const_gen(
-    model: Model,
-    generator: Generator[Tuple[np.ndarray, np.ndarray], Any, None],
-    eps=1e-4,
-    seed=None,
+    model: Model, generator: Generator[Tuple[np.ndarray, np.ndarray], Any, None]
 ):
     """
-    Evaluate the Lipschitz constant of a model, with the naive method.
-    Please note that the estimation of the lipschitz constant is done locally around
-    input sample. This may not correctly estimate the behaviour in the whole domain.
+    Evaluate the Lipschitz constant of a model, using the Jacobian of the model.
+    Please note that the estimation of the Lipschitz constant is done locally around
+    input samples. This may not correctly estimate the behaviour in the whole domain.
+    The computation might also be inaccurate in high dimensional space.

     This is the generator version of evaluate_lip_const.

     Args:
         model: built keras model used to make predictions
         generator: used to select datapoints where to compute the lipschitz constant
-        eps (float): magnitude of noise to add to input in order to compute the constant
-        seed (int): seed used when generating the noise ( can be set to None )

     Returns:
         float: the empirically evaluated lipschitz constant.

     """
     x, _ = generator.send(None)
-    return evaluate_lip_const(model, x, eps, seed=seed)
+    return evaluate_lip_const(model, x)


-def evaluate_lip_const(model: Model, x, eps=1e-4, seed=None):
+def evaluate_lip_const(model: Model, x):
     """
-    Evaluate the Lipschitz constant of a model, with the naive method.
+    Evaluate the Lipschitz constant of a model, using the Jacobian of the model.
     Please note that the estimation of the lipschitz constant is done locally around
-    input sample. This may not correctly estimate the behaviour in the whole domain.
+    input samples. This may not correctly estimate the behaviour in the whole domain.

     Args:
         model: built keras model used to make predictions
         x: inputs used to compute the lipschitz constant
-        eps (float): magnitude of noise to add to input in order to compute the constant
-        seed (int): seed used when generating the noise ( can be set to None )

     Returns:
-        float: the empirically evaluated lipschitz constant. The computation might also
+        float: the empirically evaluated Lipschitz constant. The computation might also
         be inaccurate in high dimensional space.

     """
-    y_pred = model.predict(x)
-    # x = np.repeat(x, 100, 0)
-    # y_pred = np.repeat(y_pred, 100, 0)
-    x_var = x + K.random_uniform(
-        shape=x.shape, minval=eps * 0.25, maxval=eps, seed=seed
-    )
-    y_pred_var = model.predict(x_var)
-    dx = x - x_var
-    dfx = y_pred - y_pred_var
-    ndx = K.sqrt(K.sum(K.square(dx), axis=range(1, len(x.shape))))
-    ndfx = K.sqrt(K.sum(K.square(dfx), axis=range(1, len(y_pred.shape))))
-    lip_cst = K.max(ndfx / ndx)
-    print(f"lip cst: {lip_cst:.3f}")
-    return lip_cst
+    batch_size = x.shape[0]
+    x = tf.constant(x, dtype=model.input.dtype)
+
+    # Get the jacobians of the model w.r.t. the inputs
+    with tf.GradientTape() as tape:
+        tape.watch(x)
+        y_pred = model(x, training=False)
+    batch_jacobian = tape.batch_jacobian(y_pred, x)
+
+    # Reshape the jacobians (in case of multi-dimensional input/output like in conv)
+    xdim = tf.reduce_prod(x.shape[1:])
+    ydim = tf.reduce_prod(y_pred.shape[1:])
+    batch_jacobian = tf.reshape(batch_jacobian, (batch_size, ydim, xdim))
+
+    # Compute the spectral norm of the jacobians and return the maximum
+    b = tf.norm(batch_jacobian, ord=2, axis=[-2, -1]).numpy()
+    return tf.reduce_max(b)


 def _padding_circular(x, circular_paddings):
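
Editor's note (not part of the diff): the new estimator rests on the fact that, for a differentiable model, the tightest local Lipschitz bound at an input x is the spectral norm (largest singular value) of the Jacobian at x; the maximum over a batch is therefore an empirical lower bound on the global constant. Below is a minimal, self-contained sketch of the same technique on a toy model — the architecture, sizes, and sample count are illustrative assumptions, not taken from the PR.

import numpy as np
import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense

# Toy 8 -> 16 -> 4 model (assumed dimensions, for illustration only)
inputs = Input(shape=(8,))
outputs = Dense(4)(Dense(16, activation="relu")(inputs))
model = Model(inputs, outputs)

x = tf.constant(np.random.normal(size=(32, 8)), dtype=tf.float32)

# Record the forward pass, then take per-sample Jacobians dy/dx
with tf.GradientTape() as tape:
    tape.watch(x)
    y = model(x, training=False)
jac = tape.batch_jacobian(y, x)  # shape (32, 4, 8)

# Flatten inputs/outputs to matrices (a no-op here; needed for conv models)
jac = tf.reshape(jac, (x.shape[0], -1, int(np.prod(x.shape[1:]))))

# ord=2 over the last two axes gives each Jacobian's largest singular value;
# the batch maximum is the empirical Lipschitz estimate
lip = tf.reduce_max(tf.norm(jac, ord=2, axis=[-2, -1]))
print(float(lip))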
3 changes: 1 addition & 2 deletions tests/test_compute_layer_sv.py
@@ -2,8 +2,7 @@
 # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
 # CRIAQ and ANITI - https://www.deel.ai/
 # =====================================================================================
-"""Tests for singular value computation (in compute_layer_sv.py)
-"""
+"""Tests for singular value computation (in compute_layer_sv.py)"""
 import os
 import pprint
 import unittest
4 changes: 2 additions & 2 deletions tests/test_layers.py
@@ -224,7 +224,7 @@ def train_k_lip_model(
         linear_generator(batch_size, input_shape, kernel),
         steps=10,
     )
-    empirical_lip_const = evaluate_lip_const(model=model, x=x, seed=42)
+    empirical_lip_const = evaluate_lip_const(model=model, x=x)
     # save the model
     model_checkpoint_path = os.path.join(logdir, "model.keras")
     model.save(model_checkpoint_path, overwrite=True)
@@ -237,7 +237,7 @@ def train_k_lip_model(
         linear_generator(batch_size, input_shape, kernel),
         steps=10,
     )
-    from_empirical_lip_const = evaluate_lip_const(model=model, x=x, seed=42)
+    from_empirical_lip_const = evaluate_lip_const(model=model, x=x)
     # log metrics
     file_writer = tf.summary.create_file_writer(os.path.join(logdir, "metrics"))
     file_writer.set_as_default()
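
Editor's note (not part of the diff): the test changes above drop seed=42 because the Jacobian computation is deterministic, so the old eps and seed arguments of the noise-based estimator have no counterpart. The generator variant changes the same way; a hedged usage sketch, reusing the toy model from the earlier sketch with a stand-in generator (both assumptions, not from the PR):

from deel.lip.utils import evaluate_lip_const_gen

# Hypothetical generator yielding (inputs, labels) batches, for illustration;
# evaluate_lip_const_gen pulls exactly one batch via generator.send(None)
def toy_batches():
    while True:
        x = np.random.normal(size=(16, 8)).astype("float32")
        y = np.zeros((16, 1), dtype="float32")
        yield x, y

lip = evaluate_lip_const_gen(model, toy_batches())
print(float(lip))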