|
| 1 | +# NOTE: EXPERIMENTAL API. |
| 2 | + |
"""
    CacheAllocator

Per-name cache of GPU arrays, keyed by `hash((T, dims))`.

`busy` holds arrays handed out since the last `free_busy!`; `free` holds
arrays available for re-use. All pool mutation is guarded by `lock`.
"""
struct CacheAllocator
    lock::ReentrantLock
    busy::Dict{UInt64, Vector{ROCArray}} # hash((T, dims)) => ROCArray[]
    free::Dict{UInt64, Vector{ROCArray}}
end

# Fresh allocator with empty busy/free pools.
function CacheAllocator()
    return CacheAllocator(
        ReentrantLock(),
        Dict{UInt64, Vector{ROCArray}}(),
        Dict{UInt64, Vector{ROCArray}}(),
    )
end
| 14 | + |
# Global registry of named caching allocators; guarded by the LockedObject's
# own lock for insertion (lookups take a lock-free fast path in
# `cache_allocator!`).
const CACHE_ALLOCS::LockedObject{Dict{Symbol, CacheAllocator}} =
    LockedObject(Dict{Symbol, CacheAllocator}())
| 17 | + |
"""
    cache_allocator!(cache_name::Symbol)

Return the [`CacheAllocator`](@ref) registered under `cache_name`,
creating and registering a new one if it does not exist yet.
"""
function cache_allocator!(cache_name::Symbol)
    allocs = CACHE_ALLOCS.payload
    # Lock-free fast path for the common case of an existing allocator.
    alloc = get(allocs, cache_name, nothing)
    alloc ≡ nothing || return alloc

    # Slow path: `get!` under the lock so a racing thread that inserted the
    # allocator first is not overwritten (plain assignment would drop it,
    # splitting the cache between two allocator instances).
    return Base.@lock CACHE_ALLOCS.lock get!(CacheAllocator, allocs, cache_name)
end
| 27 | + |
# Return the free pool for `uid`, creating an empty one on first use.
function get_free_pool(alloc::CacheAllocator, uid)
    # Lock-free fast path when the pool already exists.
    free_pool = get(alloc.free, uid, nothing)
    free_pool ≡ nothing || return free_pool
    # `get!` under the lock: an unconditional assignment here could replace a
    # pool that a racing thread just created (and may already reference),
    # leaking its arrays out of the cache.
    return Base.@lock alloc.lock get!(() -> ROCArray[], alloc.free, uid)
end
| 35 | + |
# Return the busy pool for `uid`, creating an empty one on first use.
function get_busy_pool(alloc::CacheAllocator, uid)
    # Lock-free fast path when the pool already exists.
    busy_pool = get(alloc.busy, uid, nothing)
    busy_pool ≡ nothing || return busy_pool
    # `get!` under the lock: an unconditional assignment here could replace a
    # pool that a racing thread just created (see `get_free_pool`).
    return Base.@lock alloc.lock get!(() -> ROCArray[], alloc.busy, uid)
end
| 43 | + |
# Try to satisfy an allocation of `ROCArray{T, N}` with shape `dims` from the
# cache. Returns a cached array on a hit (moving it free → busy), or `nothing`
# on a miss (caller performs a real allocation and registers it via
# `add_busy!`).
function alloc!(
    alloc::CacheAllocator, ::Type{Mem.HIPBuffer}, ::Type{T}, dims::Dims{N},
)::Maybe{ROCArray{T, N, Mem.HIPBuffer}} where {T, N}
    uid = hash((T, dims))
    free_pool = get_free_pool(alloc, uid)
    busy_pool = get_busy_pool(alloc, uid)

    # Mutate the pools under the lock; `pop!`/`push!` are not thread-safe.
    Base.@lock alloc.lock begin
        while !isempty(free_pool)
            x = pop!(free_pool)
            # Array was manually freed via `unsafe_free!`: drop it and keep
            # scanning — a later entry may still be alive (the original code
            # reported a miss here even when live entries remained).
            x.buf.freed && continue
            push!(busy_pool, x)
            return x
        end
    end
    return nothing
end
| 60 | + |
| 61 | +# Mark `x` array as busy, used during cache misses to add new allocations. |
| 62 | +function add_busy!(alloc::CacheAllocator, x::ROCArray{T}) where T |
| 63 | + uid = hash((T, size(x))) |
| 64 | + busy_pool = get_busy_pool(alloc, uid) |
| 65 | + Base.@lock alloc.lock push!(busy_pool, x) |
| 66 | + return |
| 67 | +end |
| 68 | + |
# Move every busy array back to its free pool, making all allocations made
# since the previous call available for re-use.
function free_busy!(alloc::CacheAllocator)
    # Hold the lock across the whole transfer. Note: `alloc.lock` is a
    # ReentrantLock, so the nested lock inside `get_free_pool` is fine.
    Base.@lock alloc.lock begin
        # `keys(...)`, not `alloc.busy.keys`: the latter is the Dict's
        # internal slot array and contains garbage/stale values in unused
        # slots — iterating it is incorrect.
        for uid in keys(alloc.busy)
            busy_pool = get_busy_pool(alloc, uid)
            isempty(busy_pool) && continue

            free_pool = get_free_pool(alloc, uid)
            append!(free_pool, busy_pool)
            empty!(busy_pool)
        end
    end
    return
end
| 81 | + |
| 82 | +# Public API. |
| 83 | + |
| 84 | +""" |
| 85 | + with_caching_allocator(f, alloc_name::Symbol, args...) |
| 86 | +
|
| 87 | +Execute function `f` with arguments `args...` using |
| 88 | +caching allocator given by its name `alloc_name`. |
| 89 | +
|
| 90 | +All GPU memory allocations will attempt to hit this cache |
| 91 | +before doing actual allocation (in case of cache miss). |
| 92 | +After executing `f`, all "busy" memory within the allocator is marked as free, |
| 93 | +so it can be re-used with the next call. |
| 94 | +
|
| 95 | +# Returns |
| 96 | +
|
| 97 | +Result of the `f` function. |
| 98 | +""" |
| 99 | +function with_caching_allocator(f, alloc_name::Symbol, args...) |
| 100 | + alloc = cache_allocator!(alloc_name) |
| 101 | + # Enable usage of cache allocator during allocations. |
| 102 | + cache_alloc_name!(alloc_name) |
| 103 | + res = f(args...) |
| 104 | + # Mark all allocations during `f` as free to re-use and disable allocator. |
| 105 | + free_busy!(alloc) |
| 106 | + cache_alloc_name!(:none) |
| 107 | + return res |
| 108 | +end |
| 109 | + |
| 110 | +""" |
| 111 | + with_no_caching(f) |
| 112 | +
|
| 113 | +Execute function `f`, but avoid hitting any caching allocator. |
| 114 | +This is useful to call from within [`with_caching_allocator`](@ref), |
| 115 | +so that the memory is independent from it. |
| 116 | +
|
| 117 | +# Returns |
| 118 | +
|
| 119 | +Result of the `f` function. |
| 120 | +""" |
| 121 | +function with_no_caching(f) |
| 122 | + alloc_name = cache_alloc_name() |
| 123 | + cache_alloc_name!(:none) |
| 124 | + res = f() |
| 125 | + cache_alloc_name!(alloc_name) |
| 126 | + return res |
| 127 | +end |
| 128 | + |
| 129 | +""" |
| 130 | + invalidate_caching_allocator!(alloc_name::Symbol) |
| 131 | +
|
| 132 | +Free all memory held by caching allocator given by it name `alloc_name`. |
| 133 | +""" |
| 134 | +function invalidate_caching_allocator!(alloc_name::Symbol) |
| 135 | + alloc = cache_allocator!(alloc_name) |
| 136 | + alloc ≡ nothing && return |
| 137 | + |
| 138 | + Base.@lock alloc.lock begin |
| 139 | + for (_, pool) in alloc.free |
| 140 | + map(AMDGPU.unsafe_free!, pool) |
| 141 | + end |
| 142 | + # TODO is other threads use the same, signal that it is invalidated somehow? |
| 143 | + # TODO error if pool is in use, i.e. non empty `busy`? |
| 144 | + for (_, pool) in alloc.busy |
| 145 | + map(AMDGPU.unsafe_free!, pool) |
| 146 | + end |
| 147 | + empty!(alloc.busy) |
| 148 | + empty!(alloc.free) |
| 149 | + end |
| 150 | + return |
| 151 | +end |
0 commit comments