Skip to content

Commit 4580cff

Browse files
committed
Add initial tests.
1 parent bd63d41 commit 4580cff

File tree

3 files changed

+129
-0
lines changed

3 files changed

+129
-0
lines changed

tests/test_extract_irreps.py

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import pytest
2+
3+
import e3nn
4+
import torch
5+
6+
from e3tools.nn import ExtractIrreps
7+
8+
9+
def test_extract_irreps():
10+
irreps_in = e3nn.o3.Irreps("0e + 1o + 2e")
11+
input = torch.as_tensor([1., 2., 3., 4., 5., 6., 7., 8., 9.])
12+
assert input.shape[-1] == irreps_in.dim
13+
14+
15+
layer = ExtractIrreps(irreps_in, "0e")
16+
output = layer(input)
17+
assert torch.allclose(output, torch.as_tensor([1.]))
18+
19+
20+
layer = ExtractIrreps(irreps_in, "1o")
21+
output = layer(input)
22+
assert torch.allclose(output, torch.as_tensor([2., 3., 4.]))
23+
24+
layer = ExtractIrreps(irreps_in, "2e")
25+
output = layer(input)
26+
assert torch.allclose(output, torch.as_tensor([5., 6., 7., 8., 9.]))
27+

tests/test_pack_unpack.py

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import pytest
2+
3+
import e3nn
4+
import torch
5+
6+
from e3tools.nn import AxisToMul, MulToAxis
7+
8+
9+
10+
@pytest.mark.parametrize(
11+
"irreps_in, factor",
12+
zip(
13+
["0e + 1o", "8x0e + 8x1o + 8x2e", "8x0e + 8x1o + 8x2e", "3x1o + 3x2o"],
14+
[1, 2, 4, 3],
15+
),
16+
)
17+
def test_axis_to_mul_shape(irreps_in: str, factor: int, batch_size: int = 5):
18+
irreps_in = e3nn.o3.Irreps(irreps_in)
19+
layer = AxisToMul(irreps_in, factor)
20+
assert layer.irreps_in == irreps_in
21+
22+
input = irreps_in.randn(batch_size, factor, -1)
23+
output = layer(input)
24+
25+
assert output.shape == (batch_size, factor * irreps_in.dim)
26+
27+
28+
@pytest.mark.parametrize(
29+
"irreps_in, factor",
30+
zip(
31+
["0e + 1o", "8x0e + 8x1o + 8x2e", "8x0e + 8x1o + 8x2e", "3x1o + 3x2o"],
32+
[1, 2, 4, 3],
33+
),
34+
)
35+
def test_mul_to_axis_shape(irreps_in: str, factor: int, batch_size: int = 5):
36+
irreps_in = e3nn.o3.Irreps(irreps_in)
37+
layer = MulToAxis(irreps_in, factor)
38+
assert layer.irreps_in == irreps_in
39+
40+
input = irreps_in.randn(batch_size, -1)
41+
output = layer(input)
42+
43+
assert output.shape == (batch_size, factor, irreps_in.dim // factor)
44+
45+
46+
47+
@pytest.mark.parametrize(
48+
"irreps_in, factor",
49+
zip(
50+
["0e + 1o", "8x0e + 8x1o + 8x2e", "8x0e + 8x1o + 8x2e", "3x1o + 3x2o"],
51+
[1, 2, 4, 3],
52+
),
53+
)
54+
def test_inverse(irreps_in: str, factor: int, batch_size: int = 5):
55+
irreps_in = e3nn.o3.Irreps(irreps_in)
56+
layer = MulToAxis(irreps_in, factor)
57+
inv_layer = AxisToMul(layer.irreps_out, factor)
58+
59+
assert layer.irreps_in == irreps_in
60+
assert inv_layer.irreps_out == irreps_in
61+
62+
input = irreps_in.randn(batch_size, -1)
63+
output = layer(input)
64+
recovered = inv_layer(output)
65+
66+
assert torch.allclose(input, recovered)

tests/test_scaling.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import pytest
2+
3+
import e3nn
4+
import torch
5+
6+
from e3tools.nn import ScaleIrreps
7+
8+
9+
@pytest.mark.parametrize("irreps_in", ["0e + 1o", "0e + 1o + 2e", "3x1o + 2x2o"])
10+
def test_scale_irreps_by_one(irreps_in):
11+
irreps_in = e3nn.o3.Irreps(irreps_in)
12+
layer = ScaleIrreps(irreps_in)
13+
assert layer.irreps_in == irreps_in
14+
assert layer.irreps_out == irreps_in
15+
16+
input = irreps_in.randn(-1)
17+
weight = torch.ones(irreps_in.num_irreps)
18+
output = layer(input, weight)
19+
20+
assert torch.allclose(input, output)
21+
22+
23+
@pytest.mark.parametrize("irreps_in", ["0e + 1o", "0e + 1o + 2e", "3x1o + 2x2o"])
24+
def test_scale_irreps_random(irreps_in: str):
25+
irreps_in = e3nn.o3.Irreps(irreps_in)
26+
layer = ScaleIrreps(irreps_in)
27+
assert layer.irreps_in == irreps_in
28+
assert layer.irreps_out == irreps_in
29+
30+
input = irreps_in.randn(-1)
31+
weight = torch.randn(irreps_in.num_irreps)
32+
output = layer(input, weight)
33+
34+
norm = e3nn.o3.Norm(irreps_in)
35+
factor = norm(output) / norm(input)
36+
assert torch.allclose(factor, torch.abs(weight))

0 commit comments

Comments
 (0)