
Commit 38fe84d

cleanup mlops (tinygrad#1521)
* cleanup mlops
* that line belongs there
1 parent 47f18f4 commit 38fe84d

4 files changed, +53 -42 lines changed

test/test_ops.py (+3 -3)

@@ -1143,9 +1143,9 @@ def test_gather(self):
     self.helper_test_exception([], lambda: tor[tb,:,:,:,:].sum().backward(), lambda: ten.gather(ta, dim=0).sum().backward(), expected=(IndexError, RuntimeError)) # torch raises IndexError, Tensor raises RuntimeError

   def test_scaled_product_attention(self):
-    helper_test_op([(32,8,128,64), (32,8,128,64), (32,8,128,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z), lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z))
-    helper_test_op([(32,8,128,64), (32,8,128,64), (32,8,128,64), (32,8,128,128)], lambda x,y,z,m: torch.nn.functional.scaled_dot_product_attention(x,y,z,attn_mask=m), lambda x,y,z,m: Tensor.scaled_dot_product_attention(x,y,z,attn_mask=m))
-    helper_test_op([(32,8,128,64), (32,8,128,64), (32,8,128,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z,is_causal=True), lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z,is_causal=True))
+    helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z), lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z))
+    helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64), (32,8,16,16)], lambda x,y,z,m: torch.nn.functional.scaled_dot_product_attention(x,y,z,attn_mask=m), lambda x,y,z,m: Tensor.scaled_dot_product_attention(x,y,z,attn_mask=m))
+    helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z,is_causal=True), lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z,is_causal=True))

 if __name__ == '__main__':
   np.random.seed(1337)

test/test_uops.py (+1 -1)

@@ -62,7 +62,7 @@ def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b)
   def test_div(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: a/b)
   def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b))
   def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b))
-  # CMPLT and MOD aren't tested
+  # MOD isn't tested

   # doesn't work in LLVM
   #def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b, dtypes.int32)

tinygrad/lazy.py (+20 -3)

@@ -195,8 +195,21 @@ def cast(self:LazyBuffer, arg:Tuple[DType, bool]) -> LazyBuffer:
     assert not arg[1] or self.dtype.itemsize == arg[0].itemsize, "can't bitcast mismatched dtype itemsizes"
     return elementwise_op(UnaryOps.CAST, self, arg=arg) if self.dtype != arg[0] else self
   def unary_op(self:LazyBuffer, op:UnaryOps) -> LazyBuffer: return elementwise_op(op, self)
-  def binary_op(self:LazyBuffer, op:BinaryOps, y:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y)
-  def ternary_op(self:LazyBuffer, op:TernaryOps, y: LazyBuffer, z:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y, z)
+  def binary_op(self:LazyBuffer, op:BinaryOps, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(op, self, y)
+  def ternary_op(self:LazyBuffer, op:TernaryOps, y:Union[LazyBuffer, float, int], z:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(op, self, y, z)
+
+  def __add__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.ADD, self, y)
+  def __radd__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.ADD, y, self)
+  def __mul__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.MUL, self, y)
+  def __rmul__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.MUL, y, self)
+  def __truediv__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.DIV, self, y)
+  def __rtruediv__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.DIV, y, self)
+  def __sub__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.SUB, self, y)
+  def __rsub__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.SUB, y, self)
+  def __lt__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.CMPLT, self, y)
+  def __gt__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.CMPLT, y, self)
+  def __neg__(self) -> LazyBuffer: return 0.0-self
+
   def contiguous(self:LazyBuffer) -> LazyBuffer:
     if not self.realized and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one
     return create_lazybuffer(self.device, ShapeTracker(self.shape), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype)
@@ -304,7 +317,11 @@ def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
       new_srcs.append(x)
   return tuple(new_srcs)

-def elementwise_op(op:Union[UnaryOps, BinaryOps, TernaryOps], *srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
+def elementwise_op(op:Union[UnaryOps, BinaryOps, TernaryOps], *_srcs:Union[LazyBuffer, float, int], arg:Optional[Any]=None) -> LazyBuffer:
+  # make them all LazyBuffers
+  first_src = [x for x in _srcs if isinstance(x, LazyBuffer)][0]
+  srcs:Tuple[LazyBuffer, ...] = tuple(x if isinstance(x, LazyBuffer) else first_src.const_like(x) for x in _srcs)
+
   # if we are separated from other binary ops by movement ops, we push those movement ops above those binaryops
   if SHUFFLE_MOVEMENT_OPS: srcs = _push_movement_ops(srcs)
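
The dunder methods above, together with the scalar promotion added to elementwise_op, are what let the mlops.py rewrite below use plain +, -, *, / and < on LazyBuffers and Python numbers. The following is a rough standalone sketch of that pattern; ToyLazyBuffer and its string "op graph" are invented for illustration only and are not tinygrad's real LazyBuffer or elementwise_op:

from __future__ import annotations
from enum import Enum, auto
from typing import Union

class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); CMPLT = auto()

class ToyLazyBuffer:
  # toy stand-in that records the op graph as a string instead of scheduling real kernels
  def __init__(self, desc:str): self.desc = desc
  def __repr__(self): return self.desc
  def const_like(self, x:Union[float, int]) -> ToyLazyBuffer: return ToyLazyBuffer(f"const({x})")
  def __add__(self, y): return elementwise_op(BinaryOps.ADD, self, y)
  def __radd__(self, y): return elementwise_op(BinaryOps.ADD, y, self)
  def __mul__(self, y): return elementwise_op(BinaryOps.MUL, self, y)
  def __rmul__(self, y): return elementwise_op(BinaryOps.MUL, y, self)
  def __sub__(self, y): return elementwise_op(BinaryOps.SUB, self, y)
  def __rsub__(self, y): return elementwise_op(BinaryOps.SUB, y, self)
  def __truediv__(self, y): return elementwise_op(BinaryOps.DIV, self, y)
  def __rtruediv__(self, y): return elementwise_op(BinaryOps.DIV, y, self)
  def __lt__(self, y): return elementwise_op(BinaryOps.CMPLT, self, y)
  def __gt__(self, y): return elementwise_op(BinaryOps.CMPLT, y, self)  # a > b is b < a
  def __neg__(self): return 0.0 - self

def elementwise_op(op:BinaryOps, *_srcs:Union[ToyLazyBuffer, float, int]) -> ToyLazyBuffer:
  # promote bare Python numbers to constant buffers made from the first lazy operand
  first_src = [x for x in _srcs if isinstance(x, ToyLazyBuffer)][0]
  srcs = tuple(x if isinstance(x, ToyLazyBuffer) else first_src.const_like(x) for x in _srcs)
  return ToyLazyBuffer(f"{op.name}({', '.join(map(repr, srcs))})")

if __name__ == "__main__":
  x = ToyLazyBuffer("x")
  print((1.0 - (0 < x)) * 2)  # MUL(SUB(const(1.0), CMPLT(const(0), x)), const(2))

With this in place, an expression such as 1.0 - (self.x < self.ret.expand(self.x.shape)) in Max.backward builds the same SUB/CMPLT nodes as the explicit const_like/binary_op chains it replaces, just written as arithmetic.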

tinygrad/mlops.py (+29 -35)

@@ -1,5 +1,5 @@
 from typing import Tuple, Optional
-from tinygrad.helpers import argsort, ShapeType
+from tinygrad.helpers import argsort, ShapeType, DType
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
 from tinygrad.tensor import Function
 from tinygrad.lazy import LazyBuffer
@@ -11,50 +11,49 @@ def backward(self, grad_output): return grad_output

 class Cast(Function):
   __slots__ = "input_dtype", "bitcast"
-  def forward(self, x, dtype, bitcast=False):
+  def forward(self, x:LazyBuffer, dtype:DType, bitcast=False):
     self.input_dtype, self.bitcast = x.dtype, bitcast
     return x.cast((dtype, bitcast))
-  def backward(self, grad_output):
+  def backward(self, grad_output:LazyBuffer):
     return grad_output.cast((self.input_dtype, self.bitcast))

 # ************* unary ops *************

 class Sin(Function):
   __slots__ = "x"
-  def forward(self, x: LazyBuffer) -> LazyBuffer:
+  def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
     return x.unary_op(UnaryOps.SIN)
-  def backward(self, grad: LazyBuffer) -> LazyBuffer:
-    return self.x.const_like(math.pi / 2).binary_op(BinaryOps.SUB, self.x).unary_op(UnaryOps.SIN).binary_op(BinaryOps.MUL, grad)
+  def backward(self, grad:LazyBuffer) -> LazyBuffer:
+    return ((math.pi / 2) - self.x).unary_op(UnaryOps.SIN) * grad

 # NOTE: maximum(x, 0) behaves differently where x=0
 class Relu(Function):
   __slots__ = "ret"
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.binary_op(BinaryOps.MAX, x.const_like(0))
+    self.ret = x.binary_op(BinaryOps.MAX, 0)
     return self.ret

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    mask = self.ret.const_like(0).binary_op(BinaryOps.CMPLT, self.ret)
-    return mask.binary_op(BinaryOps.MUL, grad_output)
+    return (0 < self.ret) * grad_output

 class Log(Function):
   __slots__ = "x"
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
-    return x.unary_op(UnaryOps.LOG2).binary_op(BinaryOps.MUL, x.const_like(math.log(2)))
+    return x.unary_op(UnaryOps.LOG2) * math.log(2)

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.binary_op(BinaryOps.DIV, self.x)
+    return grad_output / self.x

 class Exp(Function):
   __slots__ = "ret"
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.binary_op(BinaryOps.MUL, x.const_like(1/math.log(2))).unary_op(UnaryOps.EXP2)
+    self.ret = (x * (1/math.log(2))).unary_op(UnaryOps.EXP2)
     return self.ret

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return self.ret.binary_op(BinaryOps.MUL, grad_output)
+    return self.ret * grad_output

 class Sqrt(Function):
   __slots__ = "ret"
@@ -63,19 +62,19 @@ def forward(self, x:LazyBuffer) -> LazyBuffer:
     return self.ret

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.binary_op(BinaryOps.DIV, self.ret.binary_op(BinaryOps.MUL, self.ret.const_like(2)))
+    return grad_output / (self.ret * 2)

 # NOTE: the implicit derivative of sigmoid is not stable
 # https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
 # TODO: have the backend automatically find this
 class Sigmoid(Function):
   __slots__ = "ret"
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.const_like(1).binary_op(BinaryOps.DIV, x.const_like(1).binary_op(BinaryOps.ADD, x.binary_op(BinaryOps.MUL, x.const_like(-1/math.log(2))).unary_op(UnaryOps.EXP2)))
+    self.ret = 1 / (1 + (x * (-1/math.log(2))).unary_op(UnaryOps.EXP2))
     return self.ret

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return self.ret.binary_op(BinaryOps.MUL, self.ret.const_like(1).binary_op(BinaryOps.SUB, self.ret)).binary_op(BinaryOps.MUL, grad_output)
+    return (self.ret * (1 - self.ret)) * grad_output

 # ************* reduce ops *************

@@ -96,56 +95,51 @@ def forward(self, x:LazyBuffer, new_shape:ShapeType) -> LazyBuffer:

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     # 1s in locations where the max was chosen (can be two locations)
-    max_is_1s = self.x.const_like(1).binary_op(BinaryOps.SUB, self.x.binary_op(BinaryOps.CMPLT, self.ret.expand(self.x.shape)))
-
-    # sum of locations, averaged
+    max_is_1s = 1.0 - (self.x < self.ret.expand(self.x.shape))
     div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
-    max_is_amount = max_is_1s.binary_op(BinaryOps.DIV, div)
-
-    grad_output_expanded = grad_output.expand(self.x.shape)
-    return max_is_amount.binary_op(BinaryOps.MUL, grad_output_expanded)
+    return (max_is_1s / div) * grad_output.expand(self.x.shape)

 # ************* binary ops *************

 class Less(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    return x.binary_op(BinaryOps.CMPLT, y)
+    return x < y

 class Add(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    return x.binary_op(BinaryOps.ADD, y)
+    return x + y

   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     return grad_output if self.needs_input_grad[0] else None, \
            grad_output if self.needs_input_grad[1] else None

 class Sub(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    return x.binary_op(BinaryOps.SUB, y)
+    return x - y

   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     return grad_output if self.needs_input_grad[0] else None, \
-           grad_output.const_like(0).binary_op(BinaryOps.SUB, grad_output) if self.needs_input_grad[1] else None
+           -grad_output if self.needs_input_grad[1] else None

 class Mul(Function):
   __slots__ = 'x', 'y'
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
     self.x, self.y = x, y
-    return x.binary_op(BinaryOps.MUL, y)
+    return x * y

   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-    return self.y.binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, \
-           self.x.binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None
+    return self.y * grad_output if self.needs_input_grad[0] else None, \
+           self.x * grad_output if self.needs_input_grad[1] else None

 class Div(Function):
   __slots__ = 'x', 'y'
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
     self.x, self.y = x, y
-    return x.binary_op(BinaryOps.DIV, y)
+    return x / y

   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-    return grad_output.binary_op(BinaryOps.DIV, self.y) if self.needs_input_grad[0] else None, \
-           grad_output.const_like(0).binary_op(BinaryOps.SUB, grad_output).binary_op(BinaryOps.MUL, self.x).binary_op(BinaryOps.DIV, self.y.binary_op(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None
+    return grad_output / self.y if self.needs_input_grad[0] else None, \
+           (-grad_output * self.x) / (self.y * self.y) if self.needs_input_grad[1] else None

 # ************* ternary ops *************

@@ -157,8 +151,8 @@ def forward(self, x:LazyBuffer, y:LazyBuffer, z:LazyBuffer) -> LazyBuffer:

   def backward(self, grad_output:LazyBuffer):
     return None, \
-           self.x.ternary_op(TernaryOps.WHERE, grad_output, self.x.const_like(0)) if self.needs_input_grad[1] else None, \
-           self.x.ternary_op(TernaryOps.WHERE, self.x.const_like(0), grad_output) if self.needs_input_grad[2] else None
+           self.x.ternary_op(TernaryOps.WHERE, grad_output, 0) if self.needs_input_grad[1] else None, \
+           self.x.ternary_op(TernaryOps.WHERE, 0, grad_output) if self.needs_input_grad[2] else None

 # ************* movement ops *************
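
For readers checking the rewritten backward methods, they restate the usual calculus identities (plain math, not something introduced by this commit):

\[
\frac{d}{dx}\sin x = \cos x = \sin\!\left(\tfrac{\pi}{2}-x\right), \qquad
\frac{d}{dx}\log x = \frac{1}{x}, \qquad
\frac{d}{dx}e^{x} = e^{x}, \qquad
\frac{d}{dx}\sqrt{x} = \frac{1}{2\sqrt{x}},
\]
\[
\sigma(x) = \frac{1}{1+e^{-x}} = \frac{1}{1+2^{-x/\ln 2}}, \qquad
\sigma'(x) = \sigma(x)\bigl(1-\sigma(x)\bigr), \qquad
\frac{\partial}{\partial y}\!\left(\frac{x}{y}\right) = -\frac{x}{y^{2}}.
\]

In Max.backward the mask 1 - (x < max) is divided by its sum over the reduced axes, so ties split the incoming gradient evenly between the tied locations.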
