forked from tinygrad/tinygrad
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path test_linearizer.py
23 lines (20 loc) · 875 Bytes
/
test_linearizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import numpy as np
import unittest
from tinygrad.lazy import Device
from tinygrad.ops import GlobalCounters, Compiled
from tinygrad.tensor import Tensor
class TestLinearizer(unittest.TestCase):
    """Linearizer regression tests."""

    def test_arg_dedup(self):
        """A buffer referenced several times in one kernel is passed only once.

        ``a`` and ``b`` each appear twice in the expression (two shrinks of
        each). The kernel captured while realizing ``c`` must still receive
        exactly three raw buffers. Entry 0 is presumably the output buffer
        (TODO confirm). The remaining two must be exactly the realized
        buffers of ``a`` and ``b``. The numeric result is also checked
        against numpy.
        """
        if not isinstance(Device[Device.DEFAULT], Compiled):
            self.skipTest("Only Compiled supports cache")
        a, b = Tensor.randn(4), Tensor.randn(4)
        # .numpy() forces a and b to be realized before capturing starts, so
        # a.lazydata.realized / b.lazydata.realized are populated for the
        # buffer-identity check below (presumably - confirm against Tensor).
        np_a, np_b = a.numpy(), b.numpy()

        GlobalCounters.cache = []
        try:
            lhs = a.shrink(((0, 2),)) - a.shrink(((2, 4),))
            rhs = b.shrink(((0, 2),)) - b.shrink(((2, 4),))
            c = (lhs - rhs).realize()
            self.assertTrue(GlobalCounters.cache, "no kernel was captured while realizing c")
            # Each cache entry is indexed so that element [1] is the list of
            # raw buffers handed to the kernel.
            rawbufs = GlobalCounters.cache[0][1]
        finally:
            # GlobalCounters.cache is shared class-level state. Always reset
            # it, even if realize() or the capture check fails, so it cannot
            # leak into other tests.
            GlobalCounters.cache = None

        # Separate assertions give precise failure messages and, unlike a
        # bare `assert`, are not stripped when Python runs with -O.
        self.assertEqual(len(rawbufs), 3)
        self.assertEqual(set(rawbufs[1:]), {a.lazydata.realized, b.lazydata.realized})

        np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:])
        np.testing.assert_allclose(np_c, c.numpy())
if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    unittest.main()