Skip to content

Commit

Permalink
Improve tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Feb 21, 2025
1 parent 8e377c5 commit 3d3f1db
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions scico/test/optimize/test_pgm.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_pgm(self):
np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3)

def test_pgm_saveload(self):
maxiter = 10
maxiter = 5
A = linop.MatrixOperator(self.Amx)
L0 = 1.05 * linop.power_iteration(A.T @ A)[0]
loss_ = loss.SquaredL2Loss(y=self.y, A=A)
Expand All @@ -56,9 +56,14 @@ def test_pgm_saveload(self):
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, "pgm.npz")
pgm0.save_state(path)
pgm0.solve()
h0 = pgm0.history()
pgm1 = PGM(f=loss_, g=g, L0=L0, maxiter=maxiter, x0=A.adj(self.y))
pgm1.load_state(path)
np.testing.assert_allclose(pgm0.x, pgm1.x, rtol=1e-7)
pgm1.solve()
h1 = pgm1.history()
np.testing.assert_allclose(pgm0.minimizer(), pgm1.minimizer(), rtol=1e-6)
assert np.abs(h0[-1].Objective - h1[-1].Objective) < 1e-6

def test_pgm_isfinite(self):
maxiter = 5
Expand All @@ -83,7 +88,7 @@ def test_accelerated_pgm(self):
np.testing.assert_allclose(self.grdA(x), self.grdb, rtol=5e-3)

def test_accelerated_pgm_saveload(self):
maxiter = 10
maxiter = 5
A = linop.MatrixOperator(self.Amx)
L0 = 1.05 * linop.power_iteration(A.T @ A)[0]
loss_ = loss.SquaredL2Loss(y=self.y, A=A)
Expand All @@ -93,11 +98,14 @@ def test_accelerated_pgm_saveload(self):
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, "pgm.npz")
apgm0.save_state(path)
apgm0.solve()
h0 = apgm0.history()
apgm1 = AcceleratedPGM(f=loss_, g=g, L0=L0, maxiter=maxiter, x0=A.adj(self.y))
apgm1.load_state(path)
np.testing.assert_allclose(apgm0.x, apgm1.x, rtol=1e-7)
np.testing.assert_allclose(apgm0.v, apgm1.v, rtol=1e-7)
np.testing.assert_allclose(apgm0.t, apgm1.t, rtol=1e-7)
apgm1.solve()
h1 = apgm1.history()
np.testing.assert_allclose(apgm0.minimizer(), apgm1.minimizer(), rtol=1e-6)
assert np.abs(h0[-1].Objective - h1[-1].Objective) < 1e-6

def test_accelerated_pgm_isfinite(self):
maxiter = 5
Expand Down

0 comments on commit 3d3f1db

Please sign in to comment.