
Commit d6f4219

LayerNorm2d for 2 lines

1 parent 128ca16

File tree

3 files changed: +31 -8 lines

models/convnext.py (+3 -3)

@@ -1,5 +1,5 @@
 from tinygrad.tensor import Tensor
-from tinygrad.nn import Conv2d, LayerNorm, Linear
+from tinygrad.nn import Conv2d, LayerNorm, LayerNorm2d, Linear
 
 class Block:
   def __init__(self, dim):
@@ -18,8 +18,8 @@ def __call__(self, x:Tensor):
 class ConvNeXt:
   def __init__(self, in_chans=3, num_classes=1000, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]):
     self.downsample_layers = [
-      [Conv2d(in_chans, dims[0], kernel_size=4, stride=4), LayerNorm((dims[0], 1, 1), eps=1e-6)],
-      *[[LayerNorm((dims[i], 1, 1), eps=1e-6), Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2)] for i in range(len(dims)-1)]
+      [Conv2d(in_chans, dims[0], kernel_size=4, stride=4), LayerNorm2d(dims[0], eps=1e-6)],
+      *[[LayerNorm2d(dims[i], eps=1e-6), Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2)] for i in range(len(dims)-1)]
     ]
     self.stages = [[Block(dims[i]) for _ in range(depths[i])] for i in range(len(dims))]
     self.norm = LayerNorm(dims[-1])
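The swap above works because LayerNorm2d (added in tinygrad/nn/__init__.py below) normalizes over the channel axis at each spatial position, matching the channels-first LayerNorm in the reference ConvNeXt; the old LayerNorm((dims[i], 1, 1)) pattern would now trip the new shape assert, since an NCHW input's trailing dims are (C, H, W), not (C, 1, 1). A minimal usage sketch (shapes are illustrative, not from the commit):

from tinygrad.tensor import Tensor
from tinygrad.nn import LayerNorm2d

x = Tensor.randn(2, 96, 56, 56)   # illustrative NCHW map, e.g. after the 4x4 stem conv
norm = LayerNorm2d(96, eps=1e-6)  # one weight/bias value per channel
y = norm(x)                       # channels normalized at each (n, h, w); shape preserved
print(y.shape)                    # (2, 96, 56, 56)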

test/test_nn.py (+21 -2)

@@ -2,7 +2,7 @@
 import unittest
 import numpy as np
 from tinygrad.tensor import Tensor, Device
-from tinygrad.nn import BatchNorm2d, Conv2d, Linear, GroupNorm, LayerNorm
+from tinygrad.nn import BatchNorm2d, Conv2d, Linear, GroupNorm, LayerNorm, LayerNorm2d
 import torch
 
 class TestNN(unittest.TestCase):
@@ -76,7 +76,7 @@ def _test_linear(x):
   def test_conv2d(self):
     BS, C1, H, W = 4, 16, 224, 224
     C2, K, S, P = 64, 7, 2, 1
-
+
     # create in tinygrad
     layer = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
 
@@ -131,5 +131,24 @@ def test_layernorm(self):
     torch_z = torch_layer(torch_x)
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
 
+  def test_layernorm_2d(self):
+    N, C, H, W = 20, 5, 10, 10
+
+    # create in tinygrad
+    layer = LayerNorm2d(C)
+
+    # create in torch
+    with torch.no_grad():
+      torch_layer = torch.nn.LayerNorm([C]).eval()
+      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
+      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
+
+    # test
+    x = Tensor.randn(N, C, H, W)
+    z = layer(x)
+    torch_x = torch.tensor(x.cpu().numpy())
+    torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2)
+    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
+
 if __name__ == '__main__':
   unittest.main()
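Note that the torch reference is torch.nn.LayerNorm([C]) applied after an NCHW-to-NHWC permute and permuted back, mirroring exactly what LayerNorm2d does internally. Since the file ends in unittest.main(), the new case can be run on its own with the standard unittest selector (assuming torch and numpy are installed):

python3 test/test_nn.py TestNN.test_layernorm_2d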

tinygrad/nn/__init__.py (+7 -3)

@@ -73,11 +73,15 @@ def __call__(self, x:Tensor):
 
 class LayerNorm:
   def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True):
-    normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
-    self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(normalized_shape))), eps, elementwise_affine
-    self.weight, self.bias = (Tensor.ones(*normalized_shape), Tensor.zeros(*normalized_shape)) if elementwise_affine else (None, None)
+    self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
+    self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
+    self.weight, self.bias = (Tensor.ones(*self.normalized_shape), Tensor.zeros(*self.normalized_shape)) if elementwise_affine else (None, None)
 
   def __call__(self, x:Tensor):
+    assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"last dimensions of {x.shape} must match {self.normalized_shape}"
     x = x.layernorm(eps=self.eps, axis=self.axis)
     if not self.elementwise_affine: return x
     return x * self.weight + self.bias
+
+class LayerNorm2d(LayerNorm):
+  def __call__(self, x): return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
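For intuition, here is a NumPy reference of what the two-line subclass computes on an NCHW tensor; layernorm2d_ref is a hypothetical name for this sketch, not code from the commit:

import numpy as np

def layernorm2d_ref(x, weight, bias, eps=1e-5):
  # x: (N, C, H, W). LayerNorm2d permutes to NHWC, applies LayerNorm over the
  # trailing C axis, and permutes back -- equivalent to normalizing the channel
  # vector at every (n, h, w) position, then scaling/shifting per channel.
  mu = x.mean(axis=1, keepdims=True)
  var = x.var(axis=1, keepdims=True)       # biased variance, matching layernorm
  xhat = (x - mu) / np.sqrt(var + eps)
  return xhat * weight.reshape(1, -1, 1, 1) + bias.reshape(1, -1, 1, 1)

x = np.random.randn(20, 5, 10, 10).astype(np.float32)
y = layernorm2d_ref(x, np.ones(5, np.float32), np.zeros(5, np.float32))
assert abs(y.mean(axis=1)).max() < 1e-5    # per-position channel mean ~ 0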
