diff --git a/lib/gpt/core/foundation/tensor.py b/lib/gpt/core/foundation/tensor.py index 42bd45196..9d84919eb 100644 --- a/lib/gpt/core/foundation/tensor.py +++ b/lib/gpt/core/foundation/tensor.py @@ -48,6 +48,24 @@ def component_simple_map(operator, numpy_operator, extra_params, first, second): return res +def identity(t): + e = gpt.tensor(t.otype) + if len(e.array.shape) == 2: + e.array = numpy.eye(dtype=e.array.dtype, N=e.array.shape[0]) + elif len(e.array.shape) == 4: + n1 = e.array.shape[0] + n2 = e.array.shape[2] + for i in range(n1): + for j in range(n1): + if i == j: + e.array[i, j] = numpy.eye(dtype=e.array.dtype, N=n2) + else: + e.array[i, j] = numpy.zeros(dtype=e.array.dtype, shape=(n2, n2)) + else: + raise Exception(f"Unknown shape of tensor.identity {e.array.shape}") + return e + + def adj(l): if l.transposable(): return l.adj() diff --git a/lib/gpt/core/matrix/inv.py b/lib/gpt/core/matrix/inv.py index 3f175e1f9..d69df2e4d 100644 --- a/lib/gpt/core/matrix/inv.py +++ b/lib/gpt/core/matrix/inv.py @@ -17,10 +17,31 @@ # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # import gpt, cgpt +import numpy as np + + +def deflate(A): + if len(A.shape) == 4: + n1 = A.shape[0] + n2 = A.shape[2] + return n1, n2, np.swapaxes(A, 1, 2).reshape(n1 * n2, n1 * n2) + return None, None, A + + +def inflate(n1, n2, A): + if n1 is None: + return A + return np.swapaxes(A.reshape(n1, n2, n1, n2), 1, 2) def inv(A): A = gpt.eval(A) + + if isinstance(A, gpt.tensor): + n1, n2, a = deflate(A.array) + a = inflate(n1, n2, np.linalg.inv(a)) + return gpt.tensor(a, A.otype) + assert isinstance(A, gpt.lattice) to_list = gpt.util.to_list diff --git a/tests/core/matrix.py b/tests/core/matrix.py index bb69f7a75..145d75320 100755 --- a/tests/core/matrix.py +++ b/tests/core/matrix.py @@ -89,6 +89,11 @@ def mod0p1(x): g.message(f"test M*M^-1 = 1 for {m.otype.__name__}: {eps2}") assert eps2 < eps**2 + eps2 = g.norm2( + g.matrix.inv(m[0, 0, 0, 0]) * m[0, 0, 0, 0] - g.identity(m[0, 0, 0, 0]) + ) / g.norm2(g.identity(m[0, 0, 0, 0])) + assert eps2 < eps**2 * 10 + m2 = g.matrix.exp(g.matrix.log(m)) eps2 = g.norm2(m - m2) / g.norm2(m) g.message(f"exp(log(m)) == m: {eps2}")