Skip to content

Commit e79849b

Browse files
committed
add basic tests + standalone scatter/radius_graph
1 parent 434ac43 commit e79849b

9 files changed

+416
-2
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__pycache__

pyproject.toml

+9
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,18 @@ authors = [
99
requires-python = ">=3.11"
1010
dependencies = [
1111
"e3nn>=0.5.5",
12+
"jaxtyping>=0.2.38",
1213
"torch>=2.4.1",
1314
]
1415

1516
[build-system]
1617
requires = ["hatchling"]
1718
build-backend = "hatchling.build"
19+
20+
[dependency-groups]
21+
dev = [
22+
"pytest>=8.3.4",
23+
]
24+
25+
[tool.ruff.lint]
26+
ignore = ["F722"]

src/e3tools/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from ._scatter import scatter
2+
from ._radius import radius, radius_graph
3+
4+
__all__ = ["scatter", "radius", "radius_graph"]

src/e3tools/_radius.py

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import math
2+
3+
import torch
4+
from torch import Tensor
5+
6+
from jaxtyping import Float, Int64
7+
8+
9+
# ref https://github.com/rusty1s/pytorch_cluster/blob/master/torch_cluster/radius.py
10+
def radius(
11+
x: Float[Tensor, "N D"],
12+
y: Float[Tensor, "M D"],
13+
r: float,
14+
batch_x: Int64[Tensor, " N"] | None = None,
15+
batch_y: Int64[Tensor, " M"] | None = None,
16+
ignore_same_index: bool = True,
17+
chunk_size: int | None = None,
18+
) -> Int64[Tensor, "2 E"]:
19+
"""For each element in `y` find all points in `x` within distance `r`"""
20+
N, _ = x.shape
21+
M, _ = y.shape
22+
23+
if chunk_size is None:
24+
chunk_size = N + 1
25+
26+
if batch_x is None:
27+
batch_x = torch.zeros(N, dtype=torch.int64, device=x.device)
28+
29+
if batch_y is None:
30+
batch_y = torch.zeros(N, dtype=torch.int64, device=x.device)
31+
32+
if x.numel() == 0 or y.numel() == 0:
33+
return torch.empty(2, 0, dtype=torch.long, device=x.device)
34+
35+
x = x.view(-1, 1) if x.dim() == 1 else x
36+
y = y.view(-1, 1) if y.dim() == 1 else y
37+
x, y = x.contiguous(), y.contiguous()
38+
39+
batch_size = int(batch_x.max()) + 1
40+
batch_size = max(batch_size, int(batch_y.max()) + 1)
41+
assert batch_size > 0
42+
43+
r2 = torch.as_tensor(r * r, dtype=x.dtype, device=x.device)
44+
45+
n_chunks = math.ceil(N / chunk_size)
46+
47+
rows = []
48+
cols = []
49+
50+
for y_chunk, batch_y_chunk, index_y_chunk in zip(
51+
torch.chunk(y, n_chunks),
52+
torch.chunk(batch_y, n_chunks),
53+
torch.chunk(torch.arange(M, device=x.device), n_chunks),
54+
):
55+
pdist = (x[:, None] - y_chunk).pow(2).sum(dim=-1)
56+
same_batch = batch_x[:, None] == batch_y_chunk
57+
same_index = torch.arange(N, device=x.device)[:, None] == index_y_chunk
58+
59+
connected = (pdist <= r2) & same_batch
60+
if ignore_same_index:
61+
connected = connected & ~same_index
62+
63+
row, col = torch.nonzero(connected, as_tuple=True)
64+
cols.append(col + index_y_chunk[0])
65+
rows.append(row)
66+
67+
row = torch.cat(rows, dim=0)
68+
col = torch.cat(cols, dim=0)
69+
70+
return torch.stack((col, row), dim=0)
71+
72+
73+
def radius_graph(
74+
x: Float[Tensor, "N D"],
75+
r: float,
76+
batch: Int64[Tensor, " N"] | None = None,
77+
chunk_size: int | None = None,
78+
) -> Int64[Tensor, "2 E"]:
79+
return radius(x, x, r, batch, batch, ignore_same_index=True, chunk_size=chunk_size)

src/e3tools/_scatter.py

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import torch
2+
from torch import Tensor
3+
4+
5+
def broadcast(src: Tensor, other: Tensor, dim: int):
6+
if dim < 0:
7+
dim = other.dim() + dim
8+
if src.dim() == 1:
9+
for _ in range(0, dim):
10+
src = src.unsqueeze(0)
11+
for _ in range(src.dim(), other.dim()):
12+
src = src.unsqueeze(-1)
13+
src = src.expand(other.size())
14+
return src
15+
16+
17+
def scatter(src, index, dim, dim_size: int | None = None, reduce="sum"):
18+
in_shape = src.shape
19+
20+
if dim < 0:
21+
dim = src.dim() + dim
22+
23+
if dim_size is None:
24+
if index.numel() == 0:
25+
dim_size = 0
26+
else:
27+
dim_size = int(index.max()) + 1
28+
29+
index = broadcast(index, src, dim)
30+
31+
assert src.ndim == index.ndim, f"{src.ndim=}, {index.ndim=}"
32+
33+
out_shape = (*in_shape[:dim], dim_size, *in_shape[dim + 1 :])
34+
out = torch.zeros(*out_shape, dtype=src.dtype, device=src.device)
35+
36+
assert out.ndim == index.ndim, (
37+
f"{out.ndim=}, {index.ndim=} {out_shape=}, {in_shape=}, {dim=}"
38+
)
39+
return torch.scatter_reduce(out, dim, index, src, reduce, include_self=False)

src/e3tools/nn/_conv.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import e3nn
55
import torch
66
from e3nn import o3
7-
from torch_scatter import scatter
7+
8+
from e3tools import scatter
89

910
from ._gate import Gated
1011
from ._interaction import LinearSelfInteraction

src/e3tools/nn/_transformer.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import e3nn
55
import torch
66
from e3nn import o3
7-
from torch_scatter import scatter
7+
8+
from e3tools import scatter
89

910
from ._conv import Conv
1011
from ._interaction import LinearSelfInteraction

tests/test_basic.py

+207
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
import functools
2+
3+
import e3nn
4+
import pytest
5+
import torch
6+
from e3nn import o3
7+
from e3tools.nn import (
8+
Attention,
9+
Conv,
10+
ConvBlock,
11+
EquivariantMLP,
12+
ExperimentalConv,
13+
Gated,
14+
LayerNorm,
15+
MultiheadAttention,
16+
SeparableConv,
17+
TransformerBlock,
18+
)
19+
from e3tools import radius_graph
20+
21+
torch.set_default_dtype(torch.float64)
22+
23+
CONV_LAYERS = [Conv, SeparableConv, ExperimentalConv]
24+
25+
26+
def apply_layer_rot(layer):
27+
N = 20
28+
edge_attr_dim = 10
29+
max_radius = 1.3
30+
31+
pos = torch.randn(N, 3)
32+
node_attr = layer.irreps_in.randn(N, -1)
33+
34+
edge_index = radius_graph(pos, max_radius)
35+
edge_vec = pos[edge_index[0]] - pos[edge_index[1]]
36+
edge_length = (edge_vec).norm(dim=1)
37+
edge_attr = e3nn.math.soft_one_hot_linspace(
38+
edge_length,
39+
start=0.0,
40+
end=max_radius,
41+
number=edge_attr_dim,
42+
basis="smooth_finite",
43+
cutoff=True,
44+
)
45+
46+
edge_sh = o3.spherical_harmonics(
47+
layer.irreps_sh, edge_vec, True, normalization="component"
48+
)
49+
50+
rot = o3.rand_matrix()
51+
52+
D_node_attr = layer.irreps_in.D_from_matrix(rot)
53+
D_edge_sh = layer.irreps_sh.D_from_matrix(rot)
54+
55+
D_out = layer.irreps_out.D_from_matrix(rot)
56+
57+
out_1 = layer(
58+
node_attr @ D_node_attr.T, edge_index, edge_attr, edge_sh @ D_edge_sh.T
59+
)
60+
out_2 = layer(node_attr, edge_index, edge_attr, edge_sh) @ D_out.T
61+
62+
return out_1, out_2
63+
64+
65+
@pytest.mark.parametrize("conv", CONV_LAYERS)
66+
def test_conv(conv):
67+
irreps_in = o3.Irreps("10x0e + 10x1o + 10x2e")
68+
irreps_sh = irreps_in.spherical_harmonics(2)
69+
edge_attr_dim = 10
70+
71+
layer = conv(irreps_in, irreps_in, irreps_sh, edge_attr_dim=edge_attr_dim)
72+
73+
out_1, out_2 = apply_layer_rot(layer)
74+
assert torch.allclose(out_1, out_2, atol=1e-10)
75+
76+
77+
@pytest.mark.parametrize("conv", CONV_LAYERS)
78+
def test_gated_conv(conv):
79+
irreps_in = o3.Irreps("10x0e + 10x1o + 10x2e")
80+
irreps_sh = irreps_in.spherical_harmonics(2)
81+
edge_attr_dim = 10
82+
83+
wrapped = functools.partial(conv, irreps_sh=irreps_sh, edge_attr_dim=edge_attr_dim)
84+
85+
layer = Gated(wrapped, irreps_in=irreps_in, irreps_out=irreps_in)
86+
87+
out_1, out_2 = apply_layer_rot(layer)
88+
assert torch.allclose(out_1, out_2, atol=1e-10)
89+
90+
91+
@pytest.mark.parametrize("conv", CONV_LAYERS)
92+
def test_conv_block(conv):
93+
irreps_in = o3.Irreps("10x0e + 10x1o + 10x2e")
94+
irreps_sh = irreps_in.spherical_harmonics(2)
95+
edge_attr_dim = 10
96+
97+
layer = ConvBlock(
98+
irreps_in=irreps_in,
99+
irreps_out=irreps_in,
100+
irreps_sh=irreps_sh,
101+
edge_attr_dim=edge_attr_dim,
102+
conv=conv,
103+
)
104+
105+
out_1, out_2 = apply_layer_rot(layer)
106+
assert torch.allclose(out_1, out_2, atol=1e-10)
107+
108+
109+
@pytest.mark.parametrize("conv", CONV_LAYERS)
110+
def test_attention(conv):
111+
irreps_in = o3.Irreps("10x0e + 10x1o + 10x2e")
112+
irreps_out = irreps_in
113+
irreps_sh = irreps_in.spherical_harmonics(2)
114+
irreps_key = irreps_in
115+
irreps_query = irreps_in
116+
edge_attr_dim = 10
117+
118+
layer = Attention(
119+
irreps_in,
120+
irreps_out,
121+
irreps_sh,
122+
irreps_query,
123+
irreps_key,
124+
edge_attr_dim,
125+
conv=conv,
126+
)
127+
128+
out_1, out_2 = apply_layer_rot(layer)
129+
assert torch.allclose(out_1, out_2, atol=1e-10)
130+
131+
132+
@pytest.mark.parametrize("conv", [Conv, SeparableConv])
133+
def test_multihead_attention(conv):
134+
irreps_in = o3.Irreps("10x0e + 10x1o + 10x2e")
135+
irreps_out = irreps_in
136+
irreps_sh = irreps_in.spherical_harmonics(2)
137+
irreps_key = irreps_in
138+
irreps_query = irreps_in
139+
edge_attr_dim = 10
140+
n_head = 2
141+
142+
layer = MultiheadAttention(
143+
irreps_in,
144+
irreps_out,
145+
irreps_sh,
146+
irreps_query,
147+
irreps_key,
148+
edge_attr_dim,
149+
n_head,
150+
conv=conv,
151+
)
152+
153+
out_1, out_2 = apply_layer_rot(layer)
154+
assert torch.allclose(out_1, out_2, atol=1e-10)
155+
156+
157+
def test_layer_norm():
158+
irreps = o3.Irreps("10x0e + 10x1o + 10x2e")
159+
160+
layer = LayerNorm(irreps)
161+
rot = o3.rand_matrix()
162+
D = irreps.D_from_matrix(rot)
163+
164+
x = irreps.randn(10, -1)
165+
166+
out_1 = layer(x @ D.T)
167+
out_2 = layer(x) @ D.T
168+
169+
assert torch.allclose(out_1, out_2, atol=1e-10)
170+
171+
172+
def test_equivariant_mlp():
173+
irreps = o3.Irreps("10x0e + 10x1o + 10x2e")
174+
irreps_hidden = o3.Irreps([(4 * mul, ir) for mul, ir in irreps])
175+
176+
layer = EquivariantMLP(
177+
irreps, irreps, [irreps_hidden, irreps_hidden], norm_layer=LayerNorm
178+
)
179+
180+
rot = o3.rand_matrix()
181+
D = irreps.D_from_matrix(rot)
182+
183+
x = irreps.randn(10, -1)
184+
185+
out_1 = layer(x @ D.T)
186+
out_2 = layer(x) @ D.T
187+
188+
assert torch.allclose(out_1, out_2, atol=1e-10)
189+
190+
191+
def test_transformer():
192+
irreps_in = o3.Irreps("10x0e + 10x1o + 10x2e")
193+
irreps_out = irreps_in
194+
irreps_sh = irreps_in.spherical_harmonics(2)
195+
edge_attr_dim = 10
196+
n_head = 2
197+
198+
layer = TransformerBlock(
199+
irreps_in=irreps_in,
200+
irreps_out=irreps_out,
201+
irreps_sh=irreps_sh,
202+
edge_attr_dim=edge_attr_dim,
203+
n_head=n_head,
204+
)
205+
206+
out_1, out_2 = apply_layer_rot(layer)
207+
assert torch.allclose(out_1, out_2, atol=1e-10)

0 commit comments

Comments
 (0)