Skip to content

Commit 5ab1205

Browse files
authored
rng hlops: add normal and kaiming_normal (tinygrad#1378)
* add normal and kaiming_normal * make sure its float * add tests
1 parent 37fa7e9 commit 5ab1205

File tree

3 files changed

+21
-1
lines changed

3 files changed

+21
-1
lines changed

test/test_randomness.py

+11
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ def test_randn(self):
6767
self.assertTrue(normal_test(Tensor.randn))
6868
self.assertTrue(equal_distribution(Tensor.randn, torch.randn, lambda x: np.random.randn(*x)))
6969

70+
def test_normal(self):
71+
self.assertTrue(normal_test(Tensor.normal))
72+
self.assertTrue(equal_distribution(Tensor.normal, lambda x: torch.nn.init.normal_(torch.empty(x), mean=0, std=1), lambda x: np.random.normal(loc=0, scale=1, size=x)))
73+
7074
def test_uniform(self):
7175
self.assertFalse(normal_test(Tensor.uniform))
7276
self.assertTrue(equal_distribution(Tensor.uniform, lambda x: torch.nn.init.uniform_(torch.empty(x), a=-1, b=1), lambda x: np.random.uniform(low=-1, high=1, size=x)))
@@ -86,6 +90,13 @@ def test_kaiming_uniform(self):
8690
for shape in [(128, 64, 3, 3), (20, 24)]:
8791
self.assertTrue(equal_distribution(Tensor.kaiming_uniform, lambda x: torch.nn.init.kaiming_uniform_(torch.empty(x)), shape=shape))
8892

93+
def test_kaiming_normal(self):
94+
Tensor.manual_seed(1337)
95+
torch.manual_seed(1337)
96+
np.random.seed(1337)
97+
for shape in [(128, 64, 3, 3), (20, 24)]:
98+
self.assertTrue(equal_distribution(Tensor.kaiming_normal, lambda x: torch.nn.init.kaiming_normal_(torch.empty(x)), shape=shape))
99+
89100
def test_conv2d_init(self):
90101
params = (128, 256, (3,3))
91102
assert equal_distribution(lambda *_: nn.Conv2d(*params).weight, lambda _: torch.nn.Conv2d(*params).weight.detach())

test/test_tensor.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def test_gradcheck(self):
140140
self.assertFalse(gradcheck(tiny_func, tiny_x, eps = 1e-5))
141141

142142
def test_random_fns_are_deterministic_with_seed(self):
143-
for random_fn in [Tensor.randn, Tensor.uniform, Tensor.scaled_uniform, Tensor.glorot_uniform]:
143+
for random_fn in [Tensor.randn, Tensor.normal, Tensor.uniform, Tensor.scaled_uniform, Tensor.glorot_uniform, Tensor.kaiming_normal]:
144144
with self.subTest(msg=f"Tensor.{random_fn.__name__}"):
145145
Tensor.manual_seed(1337)
146146
a = random_fn(10,10).realize()

tinygrad/tensor.py

+9
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,9 @@ def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
178178
src = Tensor.rand(2, *shape, **kwargs)
179179
return src[0].mul(2*pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)
180180

181+
@staticmethod
182+
def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor: return (std * Tensor.randn(*shape, **kwargs)) + mean
183+
181184
@staticmethod
182185
def uniform(*shape, low=-1.0, high=1.0, **kwargs) -> Tensor: return ((high-low) * Tensor.rand(*shape, **kwargs)) + low
183186

@@ -194,6 +197,12 @@ def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
194197
bound = sqrt(3.0) * sqrt(2.0 / (1 + a ** 2)) / sqrt(prod(shape[1:]))
195198
return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
196199

200+
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
201+
@staticmethod
202+
def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
203+
std = sqrt(2.0 / (1 + a ** 2)) / sqrt(prod(shape[1:]))
204+
return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
205+
197206
# ***** toposort and backward pass *****
198207
def deepwalk(self):
199208
def _deepwalk(node, visited, nodes):

0 commit comments

Comments
 (0)