Skip to content

Commit

Permalink
add import functorch for torch version <2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
franckma31 committed Feb 17, 2025
1 parent d5c1e44 commit 965ade3
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions deel/torchlip/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@
from typing import Optional

import torch
import torch.func as tfc

if torch.__version__.startswith("1."):
import functorch as tfc
else:
import torch.func as tfc

from .bjorck_norm import bjorck_norm
from .bjorck_norm import remove_bjorck_norm
Expand Down Expand Up @@ -74,7 +78,6 @@ def model_func(x):
# Reshape the Jacobian to match the desired shape
batch_size = x.shape[0]
xdim = torch.prod(torch.tensor(x.shape[1:])).item()
print(batch_jacobian.shape, x.shape)
batch_jacobian = batch_jacobian.view(batch_size, -1, xdim)

# Compute singular values and check Lipschitz property
Expand Down

0 comments on commit 965ade3

Please sign in to comment.