
Commit bd7f4b1

move renamer to linearizer (tinygrad#1442)
* move renamer to linearizer
* uops converter
* Delete test_uops.py
1 parent 669b406 commit bd7f4b1

File tree

6 files changed: +28 -35 lines changed


test/unit/test_shm_tensor.py

+2 -1

@@ -1,10 +1,11 @@
 import unittest
 import multiprocessing.shared_memory as shared_memory
-from tinygrad.helpers import CI
+from tinygrad.helpers import CI, OSX
 from tinygrad.runtime.ops_shm import RawShmBuffer
 from tinygrad.tensor import Tensor, Device
 import numpy as np
 
+@unittest.skipIf(OSX, "no shm on OSX")
 class TestRawShmBuffer(unittest.TestCase):
   def test_e2e(self):
     t = Tensor.randn(2, 2, 2).realize()
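
A note on the test change above: @unittest.skipIf takes a boolean and a reason and skips the whole class when the flag is true. Below is a minimal, self-contained sketch of that pattern; the OSX flag here is a stand-in computed from platform, not the value imported from tinygrad.helpers.

import unittest, platform

OSX = platform.system() == "Darwin"  # stand-in for tinygrad.helpers.OSX

@unittest.skipIf(OSX, "no shm on OSX")
class TestExample(unittest.TestCase):
  def test_trivial(self):
    self.assertEqual(1 + 1, 2)

if __name__ == "__main__":
  unittest.main()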

tinygrad/codegen/cstyle.py

+9 -23

@@ -1,8 +1,8 @@
-from typing import Final, Dict, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Union
-import math, collections
+from typing import Dict, ClassVar, List, Optional, NamedTuple, Tuple, Union
+import math
 from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, MemOp, ConstOp
 from tinygrad.ops import ASTRunner, UnaryOps, BinaryOps, TernaryOps
-from tinygrad.helpers import ImageDType, dtypes, colored, getenv, prod, DType
+from tinygrad.helpers import ImageDType, dtypes, getenv, prod, DType
 from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable
 
 # div is different in cl than python
@@ -75,11 +75,11 @@ def render_for(self, expr: str, _min:int, _max:int) -> str:
   def render_conditional(self, cond: str, x:str, y:str) -> str:
     return f"({cond})?({x}):{y}"
 
-  def render_kernel(self, kernel:List[str], bufs:List[Tuple[str,DType]], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str,List[int],List[int]]:
+  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str,List[int],List[int]]:
     tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else ""
     buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
                 ("const " if i > 0 else "")+self.buffer_prefix+dtype.name+"*"+self.buffer_suffix) for i,(name,dtype) in enumerate(bufs)]
-    prg = ''.join([f"{self.kernel_prefix} void KERNEL_NAME_PLACEHOLDER(",] +
+    prg = ''.join([f"{self.kernel_prefix} void {function_name}(",] +
                   [', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
                   [") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
     if self.half_prekernel and any(dtype == dtypes.float16 for _,dtype in bufs): prg = ''.join([f"{self.half_prekernel}", "\n", prg])
@@ -110,7 +110,7 @@ def add_gl_dimension(prefix: str, args, i:int, var, local_size:List[int], xid:Li
     local_size.append(var.max+1)
   return "{" if isinstance(var, NumNode) else f"{{ {prefix} {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */"
 
-def uops_to_cstyle(uops:List[UOp], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]:
+def uops_to_cstyle(function_name:str, uops:List[UOp], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]:
   global_size: List[int] = []
   local_size: List[int] = []
   kernel,prekernel = [],[]
@@ -182,32 +182,18 @@ def kk(s): kernel.append(" "*depth+s)
     else:
       raise RuntimeError(f"failed to render {uop}")
 
-  return lang.render_kernel(kernel, bufs, global_size, local_size, prekernel)
+  return lang.render_kernel(function_name, kernel, bufs, global_size, local_size, prekernel)
 
 class CStyleCodegen(Linearizer):
   lang: ClassVar[CStyleLanguage] = CStyleLanguage()
   supports_constant_folding: bool = True
   supports_float4: bool = True
   supports_float4_alu: bool = True
 
-  # for renaming
-  kernel_cnt: Final[DefaultDict[str, int]] = collections.defaultdict(int)
-  kernel_name_cache: Final[Dict[str, Tuple[str, str]]] = {}
-
   def codegen(self):
     self.process()
     if self.lang.global_max: self.limit_global_dims(len(self.lang.gid), self.lang.global_max, self.lang.local_max) # NOTE: this is optional now
     self.linearize()
 
-    prg, global_size, local_size = uops_to_cstyle(self.uops, self.lang)
-
-    # painfully name the function something unique
-    if prg in CStyleCodegen.kernel_name_cache: function_name, display_name = CStyleCodegen.kernel_name_cache[prg]
-    else:
-      CStyleCodegen.kernel_cnt[self.function_name] += 1
-      suffix = f"{'n'+str(CStyleCodegen.kernel_cnt[self.function_name]-1)}" if CStyleCodegen.kernel_cnt[self.function_name] > 1 else ""
-      CStyleCodegen.kernel_name_cache[prg] = function_name, display_name = self.function_name+suffix, self.display_name+colored(suffix, 'BLACK')
-
-    return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name),
-                     global_size, local_size,
-                     op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=display_name)
+    return ASTRunner(self.function_name, *uops_to_cstyle(self.function_name, self.uops, self.lang),
+                     op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=self.display_name)
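
The render_kernel change above is the heart of the refactor: the linearizer now hands the renderer a final function_name, so the C-style backend formats it straight into the kernel source instead of emitting KERNEL_NAME_PLACEHOLDER and patching it with str.replace in codegen. A hedged, simplified sketch of that idea follows; render_kernel_sketch and its arguments are made up for illustration and are not the real CStyleLanguage API.

from typing import List

def render_kernel_sketch(function_name:str, kernel:List[str], buf_decls:List[str]) -> str:
  # the name goes directly into the generated source; no placeholder, no second pass
  return "".join([f"__kernel void {function_name}(", ", ".join(buf_decls), ") {\n", "\n".join(kernel), "\n}"])

print(render_kernel_sketch("r_add_example",
                           ["  data0[0] = data1[0] + data2[0];"],
                           ["__global float* data0", "const __global float* data1", "const __global float* data2"]))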

tinygrad/codegen/linearizer.py

+9 -3

@@ -1,4 +1,4 @@
-from typing import List, Tuple, Any, Optional, cast, DefaultDict, NamedTuple, TypeVar, Dict, Iterator, Union, Sequence
+from typing import List, Tuple, Any, Optional, cast, DefaultDict, NamedTuple, TypeVar, Dict, Iterator, Union, Sequence, Final
 import itertools, math
 from collections import defaultdict
 from enum import Enum, auto
@@ -281,6 +281,7 @@ def global_store(self, i, idxs:List[VariableOrNum], store:List[Token], ssa) -> N
       if isinstance(self.bufs[i].dtype, ImageDType): idx = to_image_idx(self.bufs[i].dtype.shape, idx, valid)
       self.uop(UOps.STORE, None, [var], MemOp(self.get_buffer_name(i), idx, self.bufs[i].__class__ is LocalBuffer, self.bufs[i].dtype, valid))
 
+  kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
   def linearize(self):
     # uops
     self.uops: List[UOp] = []
@@ -453,6 +454,11 @@ def ssa(name, ltype=dtypes.float) -> Token:
     # end the global loop
     self.uop(UOps.ENDLOOP, None, [], (global_idxs, "global"))
 
+    # name the function something unique
+    Linearizer.kernel_cnt[self.function_name] += 1
+    suffix = f"{'n'+str(Linearizer.kernel_cnt[self.function_name]-1)}" if Linearizer.kernel_cnt[self.function_name] > 1 else ""
+    self.function_name, self.display_name = self.function_name+suffix, self.display_name+colored(suffix, 'BLACK')
+
   _OT = TypeVar("_OT")
   def uop(self, uop:UOps, out:_OT, vin:List[Token], arg:Any=None) -> _OT:
     self.uops.append(UOp(uop, cast(Optional[Token], out), vin, arg))
@@ -616,7 +622,7 @@ def limit_global_dims(self, limit: int, global_max: List[int], local_max: List[i
      num_to_merge = ((self.first_reduce-self.local_dims) - limit)+1
      self.reshape_and_permute(lambda x: (prod(x[0:num_to_merge]),)+x[num_to_merge:], None)
      if DEBUG >= 3: print("reshaped to", self.full_shape, "due to too many global dimensions")
-    # Check the global allocation limit, current the global_size will be flipped during codegen
+    # Check the global allocation limit, current the global_size will be flipped during codegen
     # and then padded right with 1s if its length < 3 which makes this part a bit awkward to write
     global_dims = self.first_reduce-self.local_dims
     if global_dims > 0:
@@ -627,7 +633,7 @@ def limit_global_dims(self, limit: int, global_max: List[int], local_max: List[i
       for i in range(global_dims-1):
         if self.full_shape[i] > global_max[i]:
           order = list(range(len(self.full_shape)))
-          order[i], order[global_dims-1] = order[global_dims-1], order[i]
+          order[i], order[global_dims-1] = order[global_dims-1], order[i]
           self.reshape_and_permute(None, order)
           if DEBUG >= 3: print("permuted global dim", order, "due to allocation exceeds global limit")
 
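
The five lines added at the end of linearize() are the renamer itself: a class-level counter per base name, with repeats getting an "n1", "n2", ... suffix so every generated kernel keeps a unique function name (colored only decorates the display name). For illustration, a standalone re-implementation of just that suffix rule:

from collections import defaultdict
from typing import DefaultDict

kernel_cnt: DefaultDict[str, int] = defaultdict(int)

def unique_name(function_name:str) -> str:
  # mirrors the rule above: first occurrence unchanged, later ones get n1, n2, ...
  kernel_cnt[function_name] += 1
  suffix = f"n{kernel_cnt[function_name]-1}" if kernel_cnt[function_name] > 1 else ""
  return function_name + suffix

assert unique_name("r_add") == "r_add"
assert unique_name("r_add") == "r_addn1"
assert unique_name("r_add") == "r_addn2"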

tinygrad/codegen/llvmir.py

+4 -4

@@ -1,4 +1,4 @@
-from typing import Final, Dict, Callable, Any, List, Optional
+from typing import Final, Dict, Callable, Any, List, Optional, Tuple
 import functools
 from llvmlite import ir # type: ignore
 from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, Token, MemOp, ConstOp
@@ -32,7 +32,7 @@ def int_const(x): return ir.Constant(ir.IntType(64), x)
   TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=('fast',)), y, z, flags=('fast',)),
 }
 
-def uops_to_llvm_ir(uops:List[UOp]) -> str:
+def uops_to_llvm_ir(uops:List[UOp]) -> Tuple[str, Optional[List[int]], Optional[List[int]]]:
   # all llvm stuff goes into a module
   module = ir.Module(name=__file__)
 
@@ -131,11 +131,11 @@ def uops_to_llvm_ir(uops:List[UOp]) -> str:
       lvars[newvar] = code_for_op[args](bb[-1], *[lvars[x] for x in vin])
 
   bb[-1].ret_void()
-  return str(module)
+  return str(module), None, None
 
 class LLVMIRCodegen(Linearizer):
   def codegen(self):
     self.process()
     # no optimize, this doesn't support local
     self.linearize()
-    return ASTRunner('exec', uops_to_llvm_ir(self.uops), op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=self.display_name)
+    return ASTRunner('exec', *uops_to_llvm_ir(self.uops), op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=self.display_name)
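
Returning (str(module), None, None) instead of a bare string lets codegen pass the renderer's result straight through with star-unpacking: the three elements line up with the runner's positional program/global_size/local_size slots, and LLVM IR has no launch dimensions, hence the two Nones. A small sketch of that calling convention; FakeRunner and fake_uops_to_llvm_ir are simplified stand-ins, not the real ASTRunner signature.

from typing import List, Optional, Tuple

def fake_uops_to_llvm_ir() -> Tuple[str, Optional[List[int]], Optional[List[int]]]:
  return "; ModuleID = 'example'", None, None  # source, global_size, local_size

class FakeRunner:
  def __init__(self, name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None):
    self.name, self.prg, self.global_size, self.local_size = name, prg, global_size, local_size

runner = FakeRunner('exec', *fake_uops_to_llvm_ir())
assert runner.global_size is None and runner.local_size is None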

tinygrad/codegen/wgsl.py

+2 -2

@@ -32,12 +32,12 @@ def render_const(self, x:Union[float,int], var_dtype) -> str:
     else: val = f"{x}" + ("" if dtypes.is_int(var_dtype) else "f")
     return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 else val
 
-  def render_kernel(self, kernel:List[str], bufs:List[Tuple[str,DType]], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str, List[int], List[int]]:
+  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str, List[int], List[int]]:
     local_size = local_size[::-1] if len(local_size) else [1]
     bind_it = iter(range(len(bufs)))
     prg = "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
     prg += "\n".join(prekernel+[f"@group(0) @binding({next(bind_it)}) var<storage,read_write> {name}: array<{type_map[dtype]}>;" for name,dtype in bufs])
-    prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn KERNEL_NAME_PLACEHOLDER(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}"
+    prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}"
     return prg, global_size[::-1] if len(global_size) else [1], local_size
 
   def render_for(self, expr:str, _min:int, _max:int) -> str:

tinygrad/runtime/ops_shm.py

+2 -2

@@ -1,5 +1,5 @@
 import os, mmap
-try: import _posixshmem # not available on windows
+try: import _posixshmem # type: ignore
 except Exception: pass
 from typing import Callable, Dict
 from tinygrad.helpers import DType
@@ -16,7 +16,7 @@ def __init__(self, size, dtype:DType, device:str):
     fd = _posixshmem.shm_open(device, os.O_RDWR, 0o600)
     # TODO: these flags are somewhat platform specific, but python doesn't expose the ones we need
     shm = mmap.mmap(fd, size * dtype.itemsize, flags=mmap.MAP_SHARED | 0x2000 | 0x008000)
-    shm.madvise(mmap.MADV_HUGEPAGE)
+    shm.madvise(mmap.MADV_HUGEPAGE) # type: ignore
     os.close(fd)
     if self.cache_id is not None: SHM_CACHE[self.cache_id] = shm
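
For context on what RawShmBuffer wraps (and what the test at the top of this commit exercises): a named POSIX shared-memory segment that another process can map by name. A hedged stdlib-only sketch; the segment name "tinygrad_shm_example" is made up for illustration.

import multiprocessing.shared_memory as shared_memory
import numpy as np

# create a 4-float segment, view it as a numpy array, then clean up
shm = shared_memory.SharedMemory(name="tinygrad_shm_example", create=True, size=4 * 4)
try:
  a = np.ndarray((4,), dtype=np.float32, buffer=shm.buf)
  a[:] = [1.0, 2.0, 3.0, 4.0]
  print(a)  # another process could attach with SharedMemory(name="tinygrad_shm_example")
finally:
  shm.close()
  shm.unlink()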
