
Commit 1a03930

good changes from llama branch (tinygrad#671)

* good changes from llama
* transpose behavior changed

1 parent de1b6d3 · commit 1a03930

13 files changed: +76 -55 lines

.gitignore (+1)

@@ -16,3 +16,4 @@ vertex.bin
 recognize*
 .idea
 disassemblers/applegpu
+*.prof

extra/gemm/metal_matmul.py (+7 -13)

@@ -1,16 +1,10 @@
 import numpy as np
-from tinygrad.runtime.ops_metal import CLBuffer, CLProgram
-
-def benchmark(prog):
-  e = prog()
-  e.waitUntilCompleted()
-  return (e.GPUEndTime() - e.GPUStartTime())*1e9
-def mb(prog, N=10): return min([benchmark(prog) for _ in range(N)])
+from tinygrad.runtime.ops_metal import RawMetalBuffer, MetalProgram

 N = 2048
-a = CLBuffer(N*N*4)
-b = CLBuffer(N*N*4)
-c = CLBuffer(N*N*4)
+a = RawMetalBuffer(N*N*4)
+b = RawMetalBuffer(N*N*4)
+c = RawMetalBuffer(N*N*4)

 nb = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
 nc = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
@@ -23,7 +17,7 @@ def mb(prog, N=10): return min([benchmark(prog) for _ in range(N)])

 FLOPS = N*N*N*2

-prog = CLProgram("test", f"""
+prog = MetalProgram("test", f"""
 #include <metal_stdlib>
 #include <metal_simdgroup_matrix> // Available from Metal version 2.3 released with OS X 11.0+
 using namespace metal;
@@ -92,12 +86,12 @@ def mb(prog, N=10): return min([benchmark(prog) for _ in range(N)])
 }}
 }}
 }}""")
-tm = mb(lambda: prog([N*N//(2*4*4)], [4*32], a._cl, b._cl, c._cl))
+tm = min([prog([N*N//(2*4*4)], [4*32], a, b, c, wait=True) for _ in range(10)])
 na = a.toCPU().reshape(N,N)
 comp = nb@nc
 if N <= 32:
   print(na)
   print(comp)
-print(f"{N*N:10d} {tm*1e-3:9.2f} us, would be {FLOPS/tm:.2f} GFLOPS matmul")
+print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:.2f} GFLOPS matmul")
 np.testing.assert_allclose(na, comp, atol=1e-3)
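
The rewritten benchmark drops the local benchmark/mb helpers and times the kernel through the runtime itself. The diff implies that calling a MetalProgram with wait=True returns the elapsed kernel time in seconds, which is why the print line changed from tm*1e-3 to tm*1e6 for microseconds and to FLOPS*1e-9/tm for GFLOPS. A minimal sketch of that unit bookkeeping under that assumption (the helper below is illustrative, not part of the commit):

def report(tm_seconds: float, N: int = 2048) -> str:
  # 2*N^3 floating point ops for an N x N matmul (one multiply + one add per term)
  flops = N * N * N * 2
  us = tm_seconds * 1e6               # seconds -> microseconds, matching tm*1e6
  gflops = flops * 1e-9 / tm_seconds  # FLOPs per second, scaled to GFLOPS, matching FLOPS*1e-9/tm
  return f"{N*N:10d} {us:9.2f} us, would be {gflops:.2f} GFLOPS matmul"

print(report(1e-3))  # a hypothetical 1 ms kernel at N=2048 -> roughly 17180 GFLOPS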

models/transformer.py (+4 -4)

@@ -26,13 +26,13 @@ def attn(self, x):
       .reshape(shape=(x.shape[0], -1, self.num_heads, self.head_size)) \
       for y in [self.query, self.key, self.value]]

-    query = query.transpose(order=(0,2,1,3)) # (bs, num_heads, time, head_size)
-    key = key.transpose(order=(0,2,3,1)) # (bs, num_heads, head_size, time)
-    value = value.transpose(order=(0,2,1,3)) # (bs, num_heads, time, head_size)
+    query = query.permute(order=(0,2,1,3)) # (bs, num_heads, time, head_size)
+    key = key.permute(order=(0,2,3,1)) # (bs, num_heads, head_size, time)
+    value = value.permute(order=(0,2,1,3)) # (bs, num_heads, time, head_size)

     score = query.dot(key) * (1 / np.sqrt(self.head_size))
     weights = score.softmax() # (bs, num_heads, time, time)
-    attention = weights.dot(value).transpose(order=(0,2,1,3)) # (bs, time, num_heads, head_size)
+    attention = weights.dot(value).permute(order=(0,2,1,3)) # (bs, time, num_heads, head_size)

     return attention.reshape(shape=(x.shape[0], -1, self.num_heads * self.head_size)).linear(*self.out)
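
With transpose now reserved for a two-axis swap, any full reordering has to go through permute, so the attention code above is a mechanical rename. A small numpy sketch (mine, not from the commit) of the axis orders the comments describe:

import numpy as np

bs, time, num_heads, head_size = 2, 5, 3, 4
q = np.zeros((bs, time, num_heads, head_size))

# (bs, time, heads, head_size) -> (bs, heads, time, head_size), i.e. permute(order=(0,2,1,3))
assert np.transpose(q, (0, 2, 1, 3)).shape == (bs, num_heads, time, head_size)
# the key additionally swaps its last two axes so query.dot(key) contracts over head_size
assert np.transpose(q, (0, 2, 3, 1)).shape == (bs, num_heads, head_size, time)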

models/vit.py (+1 -1)

@@ -17,7 +17,7 @@ def __init__(self, layers=12, embed_dim=192, num_heads=3):

   def patch_embed(self, x):
     x = x.conv2d(*self.embedding, stride=16)
-    x = x.reshape(shape=(x.shape[0], x.shape[1], -1)).transpose(order=(0,2,1))
+    x = x.reshape(shape=(x.shape[0], x.shape[1], -1)).permute(order=(0,2,1))
     return x

   def forward(self, x):
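
The same rename applies here. For orientation, a shape walk-through of patch_embed (my numbers, assuming the usual 224x224 ViT input so the stride-16 patch conv yields a 14x14 grid):

import numpy as np

bs, embed_dim = 1, 192
x = np.zeros((bs, embed_dim, 14, 14))   # assumed output of the stride-16 embedding conv
x = x.reshape(bs, embed_dim, -1)        # (bs, 192, 196): flatten the patch grid
x = np.transpose(x, (0, 2, 1))          # (bs, 196, 192): one row per patch, like permute(order=(0,2,1))
assert x.shape == (bs, 196, 192)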

test/test_ops.py (+5 -5)

@@ -26,7 +26,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
   tinygrad_fp = time.monotonic() - st

   def compare(s, x,y,atol,rtol):
-    if y.shape != tuple(): assert x.shape == y.shape, f"shape mismatch {x.shape} != {y.shape}"
+    if y.shape != tuple(): assert x.shape == y.shape, f"shape mismatch (tinygrad){x.shape} != (torch){y.shape}"
     try:
       np.testing.assert_allclose(x,y, atol=atol, rtol=rtol)
     except Exception:
@@ -255,10 +255,10 @@ def test_pad2d(self):
     helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)))

   def test_transpose(self):
-    helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(order=(0,2,1)))
-    helper_test_op([(3,3,3)], lambda x: x.transpose(0,2), lambda x: x.transpose(order=(2,1,0)))
-    helper_test_op([(1,2,3,4)], lambda x: x.movedim((3,0,2,1),(0,1,2,3)), lambda x: x.transpose(order=(3,0,2,1)))
-    helper_test_op([(3,4,5,6)], lambda x: x.movedim((3,2,1,0),(0,1,2,3)), lambda x: x.transpose(order=(3,2,1,0)))
+    helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(1,2))
+    helper_test_op([(3,3,3)], lambda x: x.transpose(0,2), lambda x: x.transpose(0,2))
+    helper_test_op([(1,2,3,4)], lambda x: x.movedim((3,0,2,1),(0,1,2,3)), lambda x: x.permute(order=(3,0,2,1)))
+    helper_test_op([(3,4,5,6)], lambda x: x.movedim((3,2,1,0),(0,1,2,3)), lambda x: x.permute(order=(3,2,1,0)))

   def test_reshape(self):
     helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6)))
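
The updated tests pin down the new contract: Tensor.transpose now mirrors torch.Tensor.transpose (swap exactly two axes), while arbitrary reorderings go through permute. A numpy analogue of what both sides of the test are expected to agree on (illustrative only):

import numpy as np

x = np.arange(27).reshape(3, 3, 3)
# swapping two axes, as in x.transpose(1,2) on either side of the test
assert np.array_equal(np.swapaxes(x, 1, 2), np.transpose(x, (0, 2, 1)))

# a full reordering still needs an explicit permutation, as in x.permute(order=(3,0,2,1))
y = np.zeros((1, 2, 3, 4))
assert np.transpose(y, (3, 0, 2, 1)).shape == (4, 1, 3, 2)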

tinygrad/image.py (+22 -1)

@@ -1,6 +1,27 @@
-from tinygrad.helpers import IMAGE
+from tinygrad.helpers import IMAGE, prod
 from tinygrad.lazy import get_single_root

+def image_dot_decorator(normal_dot):
+  if IMAGE == 0: return normal_dot
+  def image_dot(self, w):
+    # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
+    bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
+    cin, cout = w.shape[-2], w.shape[-1]
+    out_shape_t = self.shape[0:-2] + (cout,-1)
+    if len(self.shape) > 1:
+      order = tuple(range(len(self.shape)-2)) + (len(self.shape)-1, len(self.shape)-2)
+    else:
+      order, out_shape_t = (0,), (cout, )
+    worder = tuple(range(len(w.shape)-2)) + (len(w.shape)-1, len(w.shape)-2)
+
+    # NOTE: with NHWC we can remove the transposes
+    # bs x groups*cin x H x W
+    cx = self.permute(order=order).reshape(shape=(bs//groups, groups*cin, -1, 1))
+    # groups*cout x cin x H, W
+    cw = w.permute(order=worder).reshape(shape=(groups*cout, cin, 1, 1))
+    return cx.conv2d(cw, groups=groups).reshape(shape=out_shape_t).permute(order=order)
+  return image_dot
+
 def image_conv2d_decorator(normal_conv):
   if IMAGE == 0: return normal_conv

tinygrad/lazy.py (+1)

@@ -21,6 +21,7 @@ def get_buffer(name, base='tinygrad.runtime'):

 class _Device:
   def __init__(self) -> None:
+    # TODO: make this dynamic to when you try to access the _buffers
     self._buffers : Dict[str, Type[DeviceBuffer]] = {x.upper():get_buffer(x) for x in
       [os.path.splitext(x)[0][len("ops_"):] for x in sorted(os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "runtime"))) if x.startswith("ops_")] if x is not None}
     self.DEFAULT : str = "CPU"

tinygrad/ops.py (+9 -5)

@@ -33,12 +33,16 @@ def map_buffers(real_srcs, x:LazyOp) -> LazyOp:
   return LazyOp(x.op, tuple((map_buffers(real_srcs, y) if isinstance(y, LazyOp) else real_srcs[y]) for y in x.src), x.arg)

 _T = TypeVar("_T")
-class RawBuffer:
-  size : int
-  def __init__(self, size): raise NotImplementedError("must be implemented")
+class Copyable:
   @classmethod
   def fromCPU(cls:Type[_T], x:np.ndarray) -> _T: raise NotImplementedError("must be implemented")
-  def toCPU(self:RawBuffer) -> np.ndarray: raise NotImplementedError("must be implemented")
+  def toCPU(self:Copyable) -> np.ndarray: raise NotImplementedError("must be implemented")
+
+class RawBuffer(Copyable): # pylint: disable=abstract-method
+  def __init__(self, size:int):
+    self.size : int = size
+    GlobalCounters.mem_used += self.size
+  def __del__(self): GlobalCounters.mem_used -= self.size

 class RawBufferCopyIn(RawBuffer):
   def copyin(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented")
@@ -58,7 +62,7 @@ def toCPU(self) -> np.ndarray:
     return x

 # a placeholder class to extend by the exec classes
-class DeviceBuffer(RawBuffer):
+class DeviceBuffer(Copyable):
   _buf: Any # underlying buffer
   shape: Tuple[int, ...]
   @classmethod
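
The refactor splits the old RawBuffer into Copyable (the fromCPU/toCPU interface) and RawBuffer (a sized allocation whose constructor and destructor keep GlobalCounters.mem_used in sync); DeviceBuffer now inherits only the interface. The runtime changes below all follow the same pattern, sketched here with a toy backend (hypothetical, and assuming RawBufferCopyIn is importable from tinygrad.ops as the runtime files suggest):

import numpy as np
from tinygrad.ops import RawBufferCopyIn   # assumed import path for this class

class RawToyBuffer(RawBufferCopyIn):       # illustrative only, not a real tinygrad backend
  def __init__(self, size):
    super().__init__(size)                 # records self.size and bumps GlobalCounters.mem_used
    self._buf = bytearray(size)            # backend-specific allocation (size is in bytes)
  def copyin(self, x: np.ndarray): self._buf[:] = x.astype(np.float32).tobytes()
  def toCPU(self) -> np.ndarray: return np.frombuffer(bytes(self._buf), dtype=np.float32)
  # no __del__ override needed: RawBuffer.__del__ decrements GlobalCounters.mem_used on free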

tinygrad/runtime/ops_clang.py (+3 -1)

@@ -6,7 +6,9 @@
 from tinygrad.codegen.gpu import GPUCodegen, GPULanguage

 class RawMallocBuffer(RawBufferCopyIn):
-  def __init__(self, size): self.size, self._buf = size, (ctypes.c_float * (size//4))()
+  def __init__(self, size):
+    super().__init__(size)
+    self._buf = (ctypes.c_float * (size//4))()
   def copyin(self, x:np.ndarray): ctypes.memmove(self._buf, x.ctypes.data, x.size*4)
   def toCPU(self): return np.ctypeslib.as_array(self._buf)

tinygrad/runtime/ops_cuda.py (+3 -1)

@@ -8,7 +8,9 @@
 from tinygrad.codegen.gpu import GPUCodegen, GPULanguage

 class RawCUDABuffer(RawBufferCopyInOut):
-  def __init__(self, size): self.size, self._cl = size, cuda.mem_alloc(size)
+  def __init__(self, size):
+    super().__init__(size)
+    self._cl = cuda.mem_alloc(size)
   def copyin(self, x:np.ndarray, stream:Optional[cuda.Stream]=None): cuda.memcpy_htod_async(self._cl, x, stream)
   def copyout(self, x:np.ndarray): cuda.memcpy_dtoh(x, self._cl)

tinygrad/runtime/ops_gpu.py (+2 -2)

@@ -30,7 +30,7 @@ class CLBuffer(RawBufferCopyInOut):
   # TODO: this can be in RawBuffer generically
   BUFFER_CACHE : ClassVar[Dict[int, List[cl.Buffer]]] = defaultdict(list)

-  def __init__(self, size):
+  def __init__(self, size): # pylint: disable=super-init-not-called
     self.size = size
     if len(CLBuffer.BUFFER_CACHE[size]) > 0:
       self._cl = CLBuffer.BUFFER_CACHE[size].pop()
@@ -50,7 +50,7 @@ class CLImage(RawBuffer): # pylint: disable=abstract-method
   fmt : Final = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT)
   IMAGE : Final = True

-  def __init__(self, shape):
+  def __init__(self, shape): # pylint: disable=super-init-not-called
     self.size, self._cl = shape, cl.Image(CL.cl_ctx, cl.mem_flags.READ_WRITE, CLImage.fmt, shape=(shape[1], shape[0]))
     GlobalCounters.mem_used += self._cl.row_pitch * self._cl.height

tinygrad/runtime/ops_metal.py (+8 -3)

@@ -20,9 +20,14 @@ def mtl_queue(self):
 METAL = _METAL()

 class RawMetalBuffer(RawBufferCopyIn):
-  def __init__(self, size): self.size, self._cl = size, METAL.device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared)
-  def __del__(self): self._cl.release()
-  def _as_np(self): return np.frombuffer(self._cl.contents().as_buffer(self._cl.length()), dtype=np.float32)
+  def __init__(self, size):
+    super().__init__(size)
+    self._cl = METAL.device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared)
+  def __del__(self):
+    self._cl.release()
+    super().__del__()
+  def _buffer(self): return self._cl.contents().as_buffer(self._cl.length())
+  def _as_np(self, dtype=np.float32): return np.frombuffer(self._buffer(), dtype=dtype)
   def copyin(self, x:np.ndarray): np.copyto(self._as_np(), x.reshape(-1).data)
   def toCPU(self) -> np.ndarray:
     for cbuf in METAL.mtl_buffers_in_flight: cbuf.waitUntilCompleted()

tinygrad/tensor.py (+10 -19)

@@ -5,7 +5,7 @@
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence
 from tinygrad.helpers import prod, argfix, make_pair, getenv, DEBUG, flatten
 from tinygrad.lazy import Device, LazyBuffer, LazyNumpyArray
-from tinygrad.image import image_conv2d_decorator
+from tinygrad.image import image_conv2d_decorator, image_dot_decorator

 # An instantiation of the Function is the Context
 class Function:
@@ -252,8 +252,10 @@ def unsqueeze(self, dim):

   # (padding_left, padding_right, padding_top, padding_bottom)
   def pad2d(self, padding:Tuple[int, ...]): return self.slice(((0,self.shape[0]), (0,self.shape[1]), (-padding[2],self.shape[2]+padding[3]), (-padding[0],self.shape[3]+padding[1])))
-  # TODO: this is totally not transpose
-  def transpose(self, order=(1,0)) -> Tensor: return self.permute(order=order)
+  def transpose(self, ax1=1, ax2=0) -> Tensor:
+    order = list(range(len(self.shape)))
+    order[ax1], order[ax2] = order[ax2], order[ax1]
+    return self.permute(order)
   def flatten(self, start_dim=0): return self.reshape(shape=tuple(list(self.shape[0:start_dim]) + [-1]))

   # ***** reduce ops *****
@@ -335,23 +337,11 @@ def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1,
     ret = (x * weight.reshape(1, groups, rcout, 1, 1, cin, H, W)).sum((-3, -2, -1)).reshape(bs, cout, oy, ox)
     return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))

+  @image_dot_decorator
   def dot(self, w:Tensor) -> Tensor:
-    # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
-    bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
-    cin, cout = w.shape[-2], w.shape[-1]
-    out_shape_t = self.shape[0:-2] + (cout,-1)
-    if len(self.shape) > 1:
-      order = tuple(range(len(self.shape)-2)) + (len(self.shape)-1, len(self.shape)-2)
-    else:
-      order, out_shape_t = (0,), (cout, )
-    worder = tuple(range(len(w.shape)-2)) + (len(w.shape)-1, len(w.shape)-2)
-
-    # NOTE: with NHWC we can remove the transposes
-    # bs x groups*cin x H x W
-    cx = self.transpose(order=order).reshape(shape=(bs//groups, groups*cin, -1, 1))
-    # groups*cout x cin x H, W
-    cw = w.transpose(order=worder).reshape(shape=(groups*cout, cin, 1, 1))
-    return cx.conv2d(cw, groups=groups).reshape(shape=out_shape_t).transpose(order=order)
+    x = self.reshape(*self.shape[0:-1], 1, self.shape[-1])
+    w = w.reshape(*w.shape[0:-2], 1, w.shape[-2], w.shape[-1]).transpose(-1, -2)
+    return (x*w).sum(-1).reshape(*x.shape[0:-2], -1)

   # ***** mlops (unary) *****

@@ -363,6 +353,7 @@ def exp(self): return mlops.Exp.apply(self)

   def __neg__(self): return 0.0-self
   def sqrt(self): return self.pow(0.5)
+  def rsqrt(self): return self.pow(-0.5)
   def square(self): return self*self
   def clip(self, min_, max_): return ((self-min_).relu()+min_) - (self-max_).relu()
   def abs(self): return self.relu() + (-self).relu()
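
The new Tensor.dot replaces the 1x1-conv trick with a broadcasted multiply-and-sum: the inputs are reshaped so the contraction axis lines up, multiplied with broadcasting, and summed. A numpy sketch of the same reshapes (mine, for a batched (2,3,4) @ (2,4,5) case):

import numpy as np

a = np.random.randn(2, 3, 4)   # (..., m, k)
b = np.random.randn(2, 4, 5)   # (..., k, n)

x = a.reshape(*a.shape[:-1], 1, a.shape[-1])                # (2, 3, 1, 4)
w = b.reshape(*b.shape[:-2], 1, b.shape[-2], b.shape[-1])   # (2, 1, 4, 5)
w = np.swapaxes(w, -1, -2)                                  # (2, 1, 5, 4), like .transpose(-1, -2)
out = (x * w).sum(-1)                                       # broadcast to (2, 3, 5, 4), contract over k
out = out.reshape(*x.shape[:-2], -1)                        # (2, 3, 5)

assert np.allclose(out, a @ b)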
