 from typing import Tuple, Optional
-from tinygrad.helpers import argsort, ShapeType
+from tinygrad.helpers import argsort, ShapeType, DType
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
 from tinygrad.tensor import Function
 from tinygrad.lazy import LazyBuffer
@@ -11,50 +11,49 @@ def backward(self, grad_output): return grad_output
 
 class Cast(Function):
   __slots__ = "input_dtype", "bitcast"
-  def forward(self, x, dtype, bitcast=False):
+  def forward(self, x:LazyBuffer, dtype:DType, bitcast=False):
     self.input_dtype, self.bitcast = x.dtype, bitcast
     return x.cast((dtype, bitcast))
-  def backward(self, grad_output):
+  def backward(self, grad_output:LazyBuffer):
     return grad_output.cast((self.input_dtype, self.bitcast))
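Note on the Cast hunk: backward just replays the cast with the saved (input_dtype, bitcast) pair. A bitcast reinterprets the raw bits rather than converting values, which is why replaying the pair in reverse is enough. A stdlib-only sketch of that distinction (not tinygrad's API):

```python
import struct

# bitcast = reinterpret the same 32 bits, no numeric conversion
bits = struct.unpack("<I", struct.pack("<f", 1.0))[0]
assert bits == 0x3F800000  # IEEE-754 float32 bit pattern of 1.0
```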
 
 # ************* unary ops *************
 
 class Sin(Function):
   __slots__ = "x"
-  def forward(self, x: LazyBuffer) -> LazyBuffer:
+  def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
     return x.unary_op(UnaryOps.SIN)
-  def backward(self, grad: LazyBuffer) -> LazyBuffer:
-    return self.x.const_like(math.pi/2).binary_op(BinaryOps.SUB, self.x).unary_op(UnaryOps.SIN).binary_op(BinaryOps.MUL, grad)
+  def backward(self, grad:LazyBuffer) -> LazyBuffer:
+    return ((math.pi/2) - self.x).unary_op(UnaryOps.SIN) * grad
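The rewritten Sin.backward still leans on the identity cos(x) = sin(π/2 − x), since the op set only exposes UnaryOps.SIN. A quick stdlib-only sanity check of the identity (a sketch, not tinygrad code):

```python
import math

# d/dx sin(x) = cos(x), computed in the kernel as sin(pi/2 - x)
for x in [-2.0, -0.5, 0.0, 1.0, 3.0]:
    assert abs(math.cos(x) - math.sin(math.pi / 2 - x)) < 1e-12
```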
 
 # NOTE: maximum(x, 0) behaves differently where x=0
 class Relu(Function):
   __slots__ = "ret"
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.binary_op(BinaryOps.MAX, x.const_like(0))
+    self.ret = x.binary_op(BinaryOps.MAX, 0)
     return self.ret
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    mask = self.ret.const_like(0).binary_op(BinaryOps.CMPLT, self.ret)
-    return mask.binary_op(BinaryOps.MUL, grad_output)
+    return (0 < self.ret) * grad_output
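In the condensed Relu.backward, the strict comparison (0 < self.ret) doubles as the gradient mask, so x = 0 gets zero gradient, consistent with the NOTE above. A scalar sketch of that convention (plain Python, illustrative names):

```python
# subgradient convention: strict >, so the gradient at exactly 0 is 0
def relu_grad(ret: float, g: float) -> float:
    return (1.0 if 0 < ret else 0.0) * g

assert relu_grad(2.0, 3.0) == 3.0
assert relu_grad(0.0, 3.0) == 0.0
```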
 
 class Log(Function):
   __slots__ = "x"
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
-    return x.unary_op(UnaryOps.LOG2).binary_op(BinaryOps.MUL, x.const_like(math.log(2)))
+    return x.unary_op(UnaryOps.LOG2) * math.log(2)
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.binary_op(BinaryOps.DIV, self.x)
+    return grad_output / self.x
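Log still computes ln via LOG2, using ln(x) = log2(x) · ln(2), and the backward is the usual 1/x rule. A stdlib check of the base change (assumes nothing beyond math):

```python
import math

# ln(x) = log2(x) * ln(2), the base change used in Log.forward
for x in [0.5, 1.0, math.e, 10.0]:
    assert abs(math.log(x) - math.log2(x) * math.log(2)) < 1e-12
```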
 
 class Exp(Function):
   __slots__ = "ret"
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.binary_op(BinaryOps.MUL, x.const_like(1/math.log(2))).unary_op(UnaryOps.EXP2)
+    self.ret = (x * (1/math.log(2))).unary_op(UnaryOps.EXP2)
     return self.ret
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return self.ret.binary_op(BinaryOps.MUL, grad_output)
+    return self.ret * grad_output
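Exp uses the same base-2 trick in the other direction: e^x = 2^(x/ln 2), and since d/dx e^x = e^x, the saved forward result is the whole backward. A stdlib sketch:

```python
import math

# e**x == 2**(x / ln(2)); Exp saves self.ret because it *is* the derivative
for x in [-1.0, 0.0, 0.5, 2.0]:
    assert abs(math.exp(x) - 2.0 ** (x / math.log(2))) < 1e-9
```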
 
 class Sqrt(Function):
   __slots__ = "ret"
@@ -63,19 +62,19 @@ def forward(self, x:LazyBuffer) -> LazyBuffer:
     return self.ret
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.binary_op(BinaryOps.DIV, self.ret.binary_op(BinaryOps.MUL, self.ret.const_like(2)))
+    return grad_output / (self.ret * 2)
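Sqrt.backward is the chain rule d/dx √x = 1/(2√x), written as grad / (ret * 2) to reuse the saved forward output. A finite-difference spot check (stdlib only):

```python
import math

# d/dx sqrt(x) = 1 / (2 * sqrt(x)), checked against a central difference
for x in [0.25, 1.0, 9.0]:
    analytic = 1.0 / (2.0 * math.sqrt(x))
    numeric = (math.sqrt(x + 1e-6) - math.sqrt(x - 1e-6)) / 2e-6
    assert abs(analytic - numeric) < 1e-5
```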
 
 # NOTE: the implicit derivative of sigmoid is not stable
 # https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
 # TODO: have the backend automatically find this
 class Sigmoid(Function):
   __slots__ = "ret"
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.const_like(1).binary_op(BinaryOps.DIV, x.const_like(1).binary_op(BinaryOps.ADD, x.binary_op(BinaryOps.MUL, x.const_like(-1/math.log(2))).unary_op(UnaryOps.EXP2)))
+    self.ret = 1 / (1 + (x * (-1/math.log(2))).unary_op(UnaryOps.EXP2))
     return self.ret
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return self.ret.binary_op(BinaryOps.MUL, self.ret.const_like(1).binary_op(BinaryOps.SUB, self.ret)).binary_op(BinaryOps.MUL, grad_output)
+    return (self.ret * (1 - self.ret)) * grad_output
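Sigmoid keeps the explicit closed-form derivative σ(1 − σ) rather than differentiating through the forward graph, which the NOTE above flags as unstable. A stdlib sketch checking that derivative against a central difference:

```python
import math

def sigmoid(x: float) -> float:
    # same rewrite as the forward: exp(-x) == 2**(-x / ln(2))
    return 1.0 / (1.0 + 2.0 ** (-x / math.log(2)))

# d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x))
for x in [-3.0, 0.0, 1.5]:
    s = sigmoid(x)
    numeric = (sigmoid(x + 1e-6) - sigmoid(x - 1e-6)) / 2e-6
    assert abs(s * (1.0 - s) - numeric) < 1e-6
```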
 
 # ************* reduce ops *************
 
@@ -96,56 +95,51 @@ def forward(self, x:LazyBuffer, new_shape:ShapeType) -> LazyBuffer:
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     # 1s in locations where the max was chosen (can be two locations)
-    max_is_1s = self.x.const_like(1).binary_op(BinaryOps.SUB, self.x.binary_op(BinaryOps.CMPLT, self.ret.expand(self.x.shape)))
-
-    # sum of locations, averaged
+    max_is_1s = 1.0 - (self.x < self.ret.expand(self.x.shape))
     div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
-    max_is_amount = max_is_1s.binary_op(BinaryOps.DIV, div)
-
-    grad_output_expanded = grad_output.expand(self.x.shape)
-    return max_is_amount.binary_op(BinaryOps.MUL, grad_output_expanded)
+    return (max_is_1s / div) * grad_output.expand(self.x.shape)
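The condensed Max.backward keeps the old semantics: build a 0/1 mask of max locations, then divide by the per-reduction count so ties share the gradient equally. A scalar-list sketch of that behavior (illustrative names, not the tinygrad API):

```python
# when the max appears twice, each location gets half the gradient
xs = [1.0, 3.0, 3.0, 2.0]
m = max(xs)
mask = [1.0 if v == m else 0.0 for v in xs]   # max_is_1s
count = sum(mask)                             # div
g = 1.0                                       # upstream gradient
grads = [mi / count * g for mi in mask]
assert grads == [0.0, 0.5, 0.5, 0.0]
```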
 
 # ************* binary ops *************
 
 class Less(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    return x.binary_op(BinaryOps.CMPLT, y)
+    return x < y
 
 class Add(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    return x.binary_op(BinaryOps.ADD, y)
+    return x + y
 
   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     return grad_output if self.needs_input_grad[0] else None, \
            grad_output if self.needs_input_grad[1] else None
 
 class Sub(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    return x.binary_op(BinaryOps.SUB, y)
+    return x - y
 
   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     return grad_output if self.needs_input_grad[0] else None, \
-           grad_output.const_like(0).binary_op(BinaryOps.SUB, grad_output) if self.needs_input_grad[1] else None
+           -grad_output if self.needs_input_grad[1] else None
 
 class Mul(Function):
   __slots__ = 'x', 'y'
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
     self.x, self.y = x, y
-    return x.binary_op(BinaryOps.MUL, y)
+    return x * y
 
   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-    return self.y.binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, \
-           self.x.binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None
+    return self.y * grad_output if self.needs_input_grad[0] else None, \
+           self.x * grad_output if self.needs_input_grad[1] else None
 
 class Div(Function):
   __slots__ = 'x', 'y'
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
     self.x, self.y = x, y
-    return x.binary_op(BinaryOps.DIV, y)
+    return x / y
 
   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-    return grad_output.binary_op(BinaryOps.DIV, self.y) if self.needs_input_grad[0] else None, \
-           grad_output.const_like(0).binary_op(BinaryOps.SUB, grad_output).binary_op(BinaryOps.MUL, self.x).binary_op(BinaryOps.DIV, self.y.binary_op(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None
+    return grad_output / self.y if self.needs_input_grad[0] else None, \
+           (-grad_output * self.x) / (self.y * self.y) if self.needs_input_grad[1] else None
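Div.backward is the quotient rule, ∂(x/y)/∂x = 1/y and ∂(x/y)/∂y = −x/y², now legible as written. A finite-difference spot check (stdlib only, illustrative values):

```python
# quotient rule: d(x/y)/dx = 1/y, d(x/y)/dy = -x / y**2
x, y, g, eps = 3.0, 2.0, 1.0, 1e-6
gx = g / y
gy = (-g * x) / (y * y)
assert abs(gx - ((x + eps) / y - (x - eps) / y) / (2 * eps)) < 1e-6
assert abs(gy - (x / (y + eps) - x / (y - eps)) / (2 * eps)) < 1e-6
```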
 
 # ************* ternary ops *************
 
@@ -157,8 +151,8 @@ def forward(self, x:LazyBuffer, y:LazyBuffer, z:LazyBuffer) -> LazyBuffer:
 
   def backward(self, grad_output:LazyBuffer):
     return None, \
-           self.x.ternary_op(TernaryOps.WHERE, grad_output, self.x.const_like(0)) if self.needs_input_grad[1] else None, \
-           self.x.ternary_op(TernaryOps.WHERE, self.x.const_like(0), grad_output) if self.needs_input_grad[2] else None
+           self.x.ternary_op(TernaryOps.WHERE, grad_output, 0) if self.needs_input_grad[1] else None, \
+           self.x.ternary_op(TernaryOps.WHERE, 0, grad_output) if self.needs_input_grad[2] else None
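Where.backward routes the upstream gradient to y where the condition holds and to z elsewhere; the condition itself gets None since it is not differentiable. A scalar sketch of the routing (illustrative helper, not the tinygrad API):

```python
# gradient routing for where(cond, y, z): y gets g where cond, z gets the rest
def where_grads(cond: bool, g: float):
    return (g if cond else 0.0), (0.0 if cond else g)

assert where_grads(True, 5.0) == (5.0, 0.0)
assert where_grads(False, 5.0) == (0.0, 5.0)
```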
 
 # ************* movement ops *************