-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmydot.jl
61 lines (53 loc) · 1.82 KB
/
mydot.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
using MappedArrays, StructArrays
using UnzipLoops
using LinearAlgebra
using BenchmarkTools
# Provided by @mcabbott and @jishnub
"""
@unzip f.(g.(x))
Like `@unzipcast` but without `@.`, thus expects a broadcasting expression.
"""
macro unzip(ex)
bc = esc(ex)
:(_unz.($bc))
end
# Marker function: it is never applied to any value. Writing `_unz.(bc)` makes
# broadcasting call `Broadcast.broadcasted(_unz, bc)`, which is intercepted below.
function _unz end

# Wrapper holding the (lazy) broadcast expression passed to `_unz.(...)`.
struct _Unz{T}
    bc::T
end

# Intercept `_unz.(bc)` and keep `bc` unevaluated inside an `_Unz` wrapper.
Broadcast.broadcasted(::typeof(_unz), bc) = _Unz(bc)

# Materialize the wrapped broadcast into a `StructArray` and hand back its
# component arrays rather than an array of tuples.
function Broadcast.materialize(u::_Unz)
    lazy = Broadcast.instantiate(u.bc)
    return StructArrays.components(StructArray(lazy))
end
# Example kernel returning two values per element: (x to the power y, x divided by y).
function f(x, y)
    power = x^y
    ratio = x / y
    return (power, ratio)
end
# Baseline
"""
    mydot_v1(f, X, Y)

Baseline: materialize `f.(X, Y)` as an array of tuples, then copy out the first
and last tuple components with `map` and return their `dot` product.
"""
@inline function mydot_v1(f, X, Y)
    zipped = broadcast(f, X, Y)
    firsts = map(first, zipped)
    lasts = map(last, zipped)
    return dot(firsts, lasts)
end
# UnzipLoops
"""
    mydot_v2(f, X, Y)

Broadcast `f` over `X` and `Y` with `UnzipLoops.broadcast_unzip`, which is
presumably one output array per tuple component (per the package docs, not shown
here). Return the `dot` product of the first two outputs.
"""
@inline function mydot_v2(f, X, Y)
    first_parts, second_parts = broadcast_unzip(f, X, Y)
    return dot(first_parts, second_parts)
end
# StructArrays
"""
    mydot_v3(f, X, Y)

Broadcast `f` over `X` and `Y` through the `@unzip` macro (StructArrays-backed)
and return the `dot` product of the first two component arrays.
"""
@inline function mydot_v3(f, X, Y)
    first_parts, second_parts = @unzip f.(X, Y)
    return dot(first_parts, second_parts)
end
# vector case
X, Y = rand(1024), rand(1024);
# All three implementations must agree. A bare `≈` result is discarded when this
# file is run as a script, so raise an error on mismatch instead.
mydot_v1(f, X, Y) ≈ mydot_v2(f, X, Y) ≈ mydot_v3(f, X, Y) ||
    error("mydot implementations disagree on the vector case")
# essential computation time
@btime f.($X, $Y); # 32.636 μs (1 allocation: 16.12 KiB)
@btime dot($X, $Y); # 51.022 ns (0 allocations: 0 bytes)
# unzip overhead: 6μs vs 0μs vs 0μs
@btime mydot_v1(f, $X, $Y); # 38.929 μs (3 allocations: 32.38 KiB)
@btime mydot_v2(f, $X, $Y); # 32.067 μs (2 allocations: 16.25 KiB)
@btime mydot_v3(f, $X, $Y); # 32.311 μs (2 allocations: 16.25 KiB)
# matrix case: X is a length-1024 vector and Y a 1×1024 matrix, so each
# broadcast `f.(X, Y)` below produces a 1024×1024 result.
X, Y = rand(1024), collect(rand(1024)');
# All three implementations must agree. A bare `≈` result is discarded when this
# file is run as a script, so raise an error on mismatch instead.
mydot_v1(f, X, Y) ≈ mydot_v2(f, X, Y) ≈ mydot_v3(f, X, Y) ||
    error("mydot implementations disagree on the matrix case")
# essential computation time
@btime f.($X, $Y); # 32.766 ms (2 allocations: 16.00 MiB)
# The `dot` inside mydot_* runs over the full broadcast result, so time it on
# arrays of that shape. `dot(X, Y)` only touches 1024 entries per argument
# (it measured 51.138 ns) and understates this step.
P, Q = rand(length(X), length(Y)), rand(length(X), length(Y));
@btime dot($P, $Q); # TODO: re-run and record timing
# unzip overhead: 11ms vs 7ms vs 13ms
# (mydot time minus f. time; also subtract the dot time above once recorded)
@btime mydot_v1(f, $X, $Y); # 43.507 ms (6 allocations: 32.00 MiB)
@btime mydot_v2(f, $X, $Y); # 39.685 ms (4 allocations: 16.00 MiB)
@btime mydot_v3(f, $X, $Y); # 45.969 ms (4 allocations: 16.00 MiB)