|
1 |
| -from typing import Final, Dict, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Union |
2 |
| -import math, collections |
| 1 | +from typing import Dict, ClassVar, List, Optional, NamedTuple, Tuple, Union |
| 2 | +import math |
3 | 3 | from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, MemOp, ConstOp
|
4 | 4 | from tinygrad.ops import ASTRunner, UnaryOps, BinaryOps, TernaryOps
|
5 |
| -from tinygrad.helpers import ImageDType, dtypes, colored, getenv, prod, DType |
| 5 | +from tinygrad.helpers import ImageDType, dtypes, getenv, prod, DType |
6 | 6 | from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable
|
7 | 7 |
|
8 | 8 | # div is different in cl than python
|
@@ -75,11 +75,11 @@ def render_for(self, expr: str, _min:int, _max:int) -> str:
|
75 | 75 | def render_conditional(self, cond: str, x:str, y:str) -> str:
|
76 | 76 | return f"({cond})?({x}):{y}"
|
77 | 77 |
|
78 |
| - def render_kernel(self, kernel:List[str], bufs:List[Tuple[str,DType]], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str,List[int],List[int]]: |
| 78 | + def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str,List[int],List[int]]: |
79 | 79 | tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else ""
|
80 | 80 | buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
|
81 | 81 | ("const " if i > 0 else "")+self.buffer_prefix+dtype.name+"*"+self.buffer_suffix) for i,(name,dtype) in enumerate(bufs)]
|
82 |
| - prg = ''.join([f"{self.kernel_prefix} void KERNEL_NAME_PLACEHOLDER(",] + |
| 82 | + prg = ''.join([f"{self.kernel_prefix} void {function_name}(",] + |
83 | 83 | [', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
|
84 | 84 | [") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
|
85 | 85 | if self.half_prekernel and any(dtype == dtypes.float16 for _,dtype in bufs): prg = ''.join([f"{self.half_prekernel}", "\n", prg])
|
@@ -110,7 +110,7 @@ def add_gl_dimension(prefix: str, args, i:int, var, local_size:List[int], xid:Li
|
110 | 110 | local_size.append(var.max+1)
|
111 | 111 | return "{" if isinstance(var, NumNode) else f"{{ {prefix} {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */"
|
112 | 112 |
|
113 |
| -def uops_to_cstyle(uops:List[UOp], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]: |
| 113 | +def uops_to_cstyle(function_name:str, uops:List[UOp], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]: |
114 | 114 | global_size: List[int] = []
|
115 | 115 | local_size: List[int] = []
|
116 | 116 | kernel,prekernel = [],[]
|
@@ -182,32 +182,18 @@ def kk(s): kernel.append(" "*depth+s)
|
182 | 182 | else:
|
183 | 183 | raise RuntimeError(f"failed to render {uop}")
|
184 | 184 |
|
185 |
| - return lang.render_kernel(kernel, bufs, global_size, local_size, prekernel) |
| 185 | + return lang.render_kernel(function_name, kernel, bufs, global_size, local_size, prekernel) |
186 | 186 |
|
187 | 187 | class CStyleCodegen(Linearizer):
|
188 | 188 | lang: ClassVar[CStyleLanguage] = CStyleLanguage()
|
189 | 189 | supports_constant_folding: bool = True
|
190 | 190 | supports_float4: bool = True
|
191 | 191 | supports_float4_alu: bool = True
|
192 | 192 |
|
193 |
| - # for renaming |
194 |
| - kernel_cnt: Final[DefaultDict[str, int]] = collections.defaultdict(int) |
195 |
| - kernel_name_cache: Final[Dict[str, Tuple[str, str]]] = {} |
196 |
| - |
197 | 193 | def codegen(self):
|
198 | 194 | self.process()
|
199 | 195 | if self.lang.global_max: self.limit_global_dims(len(self.lang.gid), self.lang.global_max, self.lang.local_max) # NOTE: this is optional now
|
200 | 196 | self.linearize()
|
201 | 197 |
|
202 |
| - prg, global_size, local_size = uops_to_cstyle(self.uops, self.lang) |
203 |
| - |
204 |
| - # painfully name the function something unique |
205 |
| - if prg in CStyleCodegen.kernel_name_cache: function_name, display_name = CStyleCodegen.kernel_name_cache[prg] |
206 |
| - else: |
207 |
| - CStyleCodegen.kernel_cnt[self.function_name] += 1 |
208 |
| - suffix = f"{'n'+str(CStyleCodegen.kernel_cnt[self.function_name]-1)}" if CStyleCodegen.kernel_cnt[self.function_name] > 1 else "" |
209 |
| - CStyleCodegen.kernel_name_cache[prg] = function_name, display_name = self.function_name+suffix, self.display_name+colored(suffix, 'BLACK') |
210 |
| - |
211 |
| - return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name), |
212 |
| - global_size, local_size, |
213 |
| - op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=display_name) |
| 198 | + return ASTRunner(self.function_name, *uops_to_cstyle(self.function_name, self.uops, self.lang), |
| 199 | + op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=self.display_name) |
0 commit comments