from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence
from tinygrad.helpers import prod, argfix, make_pair, getenv, DEBUG, flatten
from tinygrad.lazy import Device, LazyBuffer, LazyNumpyArray
- from tinygrad.image import image_conv2d_decorator
+ from tinygrad.image import image_conv2d_decorator, image_dot_decorator

# An instantiation of the Function is the Context
class Function:
@@ -252,8 +252,10 @@ def unsqueeze(self, dim):
# (padding_left, padding_right, padding_top, padding_bottom)
def pad2d(self, padding:Tuple[int, ...]): return self.slice(((0,self.shape[0]), (0,self.shape[1]), (-padding[2],self.shape[2]+padding[3]), (-padding[0],self.shape[3]+padding[1])))
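The padding tuple follows the comment above, (left, right, top, bottom), applied to the last two axes of an NCHW tensor. A quick shape sketch, not part of the diff, assuming Tensor is imported from tinygrad.tensor:

from tinygrad.tensor import Tensor
t = Tensor.ones(1, 1, 3, 3)              # NCHW input
p = t.pad2d(padding=(1, 2, 3, 4))        # (left, right, top, bottom)
assert p.shape == (1, 1, 3+3+4, 3+1+2)   # H grows by top+bottom, W by left+right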
- # TODO: this is totally not transpose
- def transpose(self, order=(1,0)) -> Tensor: return self.permute(order=order)
+ def transpose(self, ax1=1, ax2=0) -> Tensor:
+   order = list(range(len(self.shape)))
+   order[ax1], order[ax2] = order[ax2], order[ax1]
+   return self.permute(order)
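The old transpose was a permute in disguise; the new signature swaps exactly two axes (permute remains for full reorderings). A small sketch of the new behavior, assuming Tensor from tinygrad.tensor:

from tinygrad.tensor import Tensor
t = Tensor.ones(2, 3, 4)
assert t.transpose().shape == (3, 2, 4)        # default swaps axes 1 and 0
assert t.transpose(-1, -2).shape == (2, 4, 3)  # negative axes index from the end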
def flatten(self, start_dim=0): return self.reshape(shape=tuple(list(self.shape[0:start_dim]) + [-1]))
# ***** reduce ops *****
@@ -335,23 +337,11 @@ def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1,
  ret = (x * weight.reshape(1, groups, rcout, 1, 1, cin, H, W)).sum((-3, -2, -1)).reshape(bs, cout, oy, ox)
  return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
+ @image_dot_decorator
def dot(self, w:Tensor) -> Tensor:
-   # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
-   bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
-   cin, cout = w.shape[-2], w.shape[-1]
-   out_shape_t = self.shape[0:-2] + (cout,-1)
-   if len(self.shape) > 1:
-     order = tuple(range(len(self.shape)-2)) + (len(self.shape)-1, len(self.shape)-2)
-   else:
-     order, out_shape_t = (0,), (cout,)
-   worder = tuple(range(len(w.shape)-2)) + (len(w.shape)-1, len(w.shape)-2)
-
-   # NOTE: with NHWC we can remove the transposes
-   # bs x groups*cin x H x W
-   cx = self.transpose(order=order).reshape(shape=(bs//groups, groups*cin, -1, 1))
-   # groups*cout x cin x H, W
-   cw = w.transpose(order=worder).reshape(shape=(groups*cout, cin, 1, 1))
-   return cx.conv2d(cw, groups=groups).reshape(shape=out_shape_t).transpose(order=order)
+   x = self.reshape(*self.shape[0:-1], 1, self.shape[-1])
+   w = w.reshape(*w.shape[0:-2], 1, w.shape[-2], w.shape[-1]).transpose(-1, -2)
+   return (x*w).sum(-1).reshape(*x.shape[0:-2], -1)
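The 1x1-conv matmul is replaced with broadcasting: (..., m, k) becomes (..., m, 1, k), w becomes (..., 1, n, k) after the reshape and transpose, the elementwise product broadcasts to (..., m, n, k), and the sum over the last axis yields (..., m, n). A sanity sketch against numpy, not part of the diff, assuming Tensor from tinygrad.tensor accepts a numpy array and .numpy() realizes the result:

import numpy as np
from tinygrad.tensor import Tensor
a = np.random.randn(2, 3, 4).astype(np.float32)
b = np.random.randn(2, 4, 5).astype(np.float32)
out = Tensor(a).dot(Tensor(b))
assert out.shape == (2, 3, 5)   # (..., m, k) @ (..., k, n) -> (..., m, n)
np.testing.assert_allclose(out.numpy(), a @ b, atol=1e-4)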
# ***** mlops (unary) *****
@@ -363,6 +353,7 @@ def exp(self): return mlops.Exp.apply(self)
def __neg__(self): return 0.0 - self
def sqrt(self): return self.pow(0.5)
+ def rsqrt(self): return self.pow(-0.5)
def square(self): return self * self
def clip(self, min_, max_): return ((self - min_).relu() + min_) - (self - max_).relu()
def abs(self): return self.relu() + (-self).relu()
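These composed ops lean on existing primitives: rsqrt is pow(-0.5), while clip and abs are built from relu. A few numeric spot checks, again a sketch assuming Tensor from tinygrad.tensor:

import numpy as np
from tinygrad.tensor import Tensor
x = Tensor(np.array([1.0, 4.0, 9.0], dtype=np.float32))
np.testing.assert_allclose(x.rsqrt().numpy(), [1.0, 0.5, 1/3], atol=1e-5)
y = Tensor(np.array([-2.0, 0.5, 3.0], dtype=np.float32))
np.testing.assert_allclose(y.clip(-1.0, 1.0).numpy(), [-1.0, 0.5, 1.0], atol=1e-5)
np.testing.assert_allclose(y.abs().numpy(), [2.0, 0.5, 3.0], atol=1e-5)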