@@ -2,7 +2,7 @@
 import unittest
 import numpy as np
 from tinygrad.tensor import Tensor, Device
-from tinygrad.nn import BatchNorm2d, Conv2d, Linear, GroupNorm, LayerNorm
+from tinygrad.nn import BatchNorm2d, Conv2d, Linear, GroupNorm, LayerNorm, LayerNorm2d
 import torch
 
 class TestNN(unittest.TestCase):
@@ -76,7 +76,7 @@ def _test_linear(x):
   def test_conv2d(self):
     BS, C1, H, W = 4, 16, 224, 224
     C2, K, S, P = 64, 7, 2, 1
-
+
     # create in tinygrad
     layer = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
 
@@ -131,5 +131,24 @@ def test_layernorm(self):
     torch_z = torch_layer(torch_x)
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
 
+  def test_layernorm_2d(self):
+    N, C, H, W = 20, 5, 10, 10
+
+    # create in tinygrad
+    layer = LayerNorm2d(C)
+
+    # create in torch
+    with torch.no_grad():
+      torch_layer = torch.nn.LayerNorm([C]).eval()
+      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
+      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
+
+    # test
+    x = Tensor.randn(N, C, H, W)
+    z = layer(x)
+    torch_x = torch.tensor(x.cpu().numpy())
+    torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2)
+    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
+
 if __name__ == '__main__':
   unittest.main()
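For context on what the new test checks: LayerNorm2d applies layer normalization over the channel axis of an NCHW tensor, which is why the torch reference permutes to NHWC, runs a plain LayerNorm over the last axis, and permutes back. Below is a minimal sketch of a layer with this behavior, assuming tinygrad's LayerNorm normalizes over the trailing dimension; the subclass name and its exact form are illustrative, not necessarily tinygrad's actual implementation.

from tinygrad.nn import LayerNorm

class LayerNorm2dSketch(LayerNorm):
  # Move channels last so LayerNorm's trailing-axis reduction hits C,
  # then restore the NCHW layout. Statistics are therefore computed per
  # (n, h, w) location across the C channels.
  def __call__(self, x):
    return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)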
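Equivalently, the assertion can be understood against a hand-rolled NumPy reference. This sketch assumes eps=1e-5 (torch's LayerNorm default) and the biased variance used by layer norm; layernorm2d_ref is a hypothetical helper, not part of either library.

import numpy as np

def layernorm2d_ref(x, w, b, eps=1e-5):
  # x: (N, C, H, W); normalize across channels at each spatial location
  mu = x.mean(axis=1, keepdims=True)
  var = x.var(axis=1, keepdims=True)  # biased variance, as in layer norm
  xhat = (x - mu) / np.sqrt(var + eps)
  # scale and shift per channel, broadcasting over N, H, W
  return xhat * w.reshape(1, -1, 1, 1) + b.reshape(1, -1, 1, 1)

With the test's tensors this would be checked the same way, e.g. np.testing.assert_allclose(z.numpy(), layernorm2d_ref(x.numpy(), layer.weight.numpy(), layer.bias.numpy()), atol=5e-3, rtol=5e-3).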