Skip to content

Commit f8942af

Browse files
Support old mlptrain npz files (#373)
* Patch ValueArray pickle protocol to allow loading of old npz files * Add tests * Add changelog entry * Don't fail fast * Format * revert fail-fast --------- Co-authored-by: Tom Young <39765193+t-young31@users.noreply.github.com>
1 parent de8ed4a commit f8942af

File tree

4 files changed

+34
-2
lines changed

4 files changed

+34
-2
lines changed

autode/values.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -664,8 +664,15 @@ def __reduce__(self):
664664
)
665665

666666
def __setstate__(self, state, *args, **kwargs):
667-
self.__dict__.update(state[-1])
668-
super().__setstate__(state[:-1], *args, **kwargs)
667+
"""Extend default pickling protocol to include extra attributes from ValueArray"""
668+
669+
try:
670+
self.__dict__.update(state[-1])
671+
super().__setstate__(state[:-1], *args, **kwargs)
672+
except TypeError:
673+
# This is a fallback so we can load old .npz files in mlptrain, see:
674+
# https://github.com/duartegroup/autodE/issues/372
675+
super().__setstate__(state, *args, **kwargs)
669676

670677
def to(self, units) -> Any:
671678
"""

doc/changelog.rst

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ Functionality improvements
1313
Bug Fixes
1414
*********
1515
- Fixes coordinate extraction in some G16 output files
16+
- Fixes loading old mlptrain .npz files
1617

1718
Usability improvements/Changes
1819
******************************

tests/data/old_mlptrain.npz

17.2 KB
Binary file not shown.

tests/test_values.py

+24
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
13
import numpy as np
24
import pytest
35

@@ -12,6 +14,8 @@
1214
_to,
1315
)
1416

17+
here = os.path.dirname(os.path.abspath(__file__))
18+
1519

1620
class TmpValues(ValueArray):
1721
implemented_units = [ha, ev]
@@ -123,3 +127,23 @@ def test_force_constant():
123127

124128
# should be able to convert to Ha/a0^2 without any problems
125129
_ = fc.to("Ha/a0^2")
130+
131+
132+
def test_pickle():
133+
"""Regression test for https://github.com/duartegroup/autodE/issues/221"""
134+
import pickle
135+
136+
x = Gradient([[1.0, 1.0, 1.0]], units=ha_per_ang)
137+
pickled_x = pickle.dumps(x, pickle.HIGHEST_PROTOCOL)
138+
unpickled_x = pickle.loads(pickled_x)
139+
assert unpickled_x.units == ha_per_ang
140+
assert unpickled_x == x
141+
142+
143+
def test_load_old_mlptrain_npz():
144+
"""Regression test for https://github.com/duartegroup/autodE/issues/372"""
145+
import numpy as np
146+
147+
npz_file = os.path.join(here, "data", "old_mlptrain.npz")
148+
data = np.load(npz_file, allow_pickle=True)
149+
assert data["F_true"] is not None

0 commit comments

Comments
 (0)