Skip to content

Commit

Permalink
inverse tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
lehner committed Oct 30, 2024
1 parent 073edfe commit 3fdc476
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 0 deletions.
18 changes: 18 additions & 0 deletions lib/gpt/core/foundation/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,24 @@ def component_simple_map(operator, numpy_operator, extra_params, first, second):
return res


def identity(t):
e = gpt.tensor(t.otype)
if len(e.array.shape) == 2:
e.array = numpy.eye(dtype=e.array.dtype, N=e.array.shape[0])
elif len(e.array.shape) == 4:
n1 = e.array.shape[0]
n2 = e.array.shape[2]
for i in range(n1):
for j in range(n1):
if i == j:
e.array[i, j] = numpy.eye(dtype=e.array.dtype, N=n2)
else:
e.array[i, j] = numpy.zeros(dtype=e.array.dtype, shape=(n2, n2))
else:
raise Exception(f"Unknown shape of tensor.identity {e.array.shape}")
return e


def adj(l):
if l.transposable():
return l.adj()
Expand Down
21 changes: 21 additions & 0 deletions lib/gpt/core/matrix/inv.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,31 @@
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
import gpt, cgpt
import numpy as np


def deflate(A):
if len(A.shape) == 4:
n1 = A.shape[0]
n2 = A.shape[2]
return n1, n2, np.swapaxes(A, 1, 2).reshape(n1 * n2, n1 * n2)
return None, None, A


def inflate(n1, n2, A):
if n1 is None:
return A
return np.swapaxes(A.reshape(n1, n2, n1, n2), 1, 2)


def inv(A):
A = gpt.eval(A)

if isinstance(A, gpt.tensor):
n1, n2, a = deflate(A.array)
a = inflate(n1, n2, np.linalg.inv(a))
return gpt.tensor(a, A.otype)

assert isinstance(A, gpt.lattice)

to_list = gpt.util.to_list
Expand Down
5 changes: 5 additions & 0 deletions tests/core/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ def mod0p1(x):
g.message(f"test M*M^-1 = 1 for {m.otype.__name__}: {eps2}")
assert eps2 < eps**2

eps2 = g.norm2(
g.matrix.inv(m[0, 0, 0, 0]) * m[0, 0, 0, 0] - g.identity(m[0, 0, 0, 0])
) / g.norm2(g.identity(m[0, 0, 0, 0]))
assert eps2 < eps**2 * 10

m2 = g.matrix.exp(g.matrix.log(m))
eps2 = g.norm2(m - m2) / g.norm2(m)
g.message(f"exp(log(m)) == m: {eps2}")
Expand Down

0 comments on commit 3fdc476

Please sign in to comment.