Skip to content

Commit 98e7084

Browse files
authored
[Experimental] Add memory recording which can then be bulk-freed (#698)
1 parent 4385ed9 commit 98e7084

File tree

4 files changed

+54
-3
lines changed

4 files changed

+54
-3
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AMDGPU"
22
uuid = "21141c5a-9bdb-4563-92ae-f87d6854732e"
33
authors = ["Julian P Samaroo <jpsamaroo@jpsamaroo.me>", "Valentin Churavy <v.churavy@gmail.com>", "Anton Smirnov <tonysmn97@gmail.com>"]
4-
version = "1.1.0"
4+
version = "1.1.1"
55

66
[deps]
77
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

src/AMDGPU.jl

+1
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ include("tls.jl")
114114
include("highlevel.jl")
115115
include("reflection.jl")
116116
include("array.jl")
117+
include("memory_record.jl")
117118
include("conversions.jl")
118119
include("broadcast.jl")
119120
include("exception_handler.jl")

src/array.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@ mutable struct ROCArray{T, N, B} <: AbstractGPUArray{T, N}
88
) where {T, N, B <: Mem.AbstractAMDBuffer}
99
@assert isbitstype(T) "ROCArray only supports bits types"
1010
data = DataRef(pool_free, pool_alloc(B, prod(dims) * sizeof(T)))
11-
xs = new{T, N, B}(data, dims, 0)
12-
return finalizer(unsafe_free!, xs)
11+
x = new{T, N, B}(data, dims, 0)
12+
x = finalizer(unsafe_free!, x)
13+
RECORD_MEMORY[] && record!(x)
14+
return x
1315
end
1416

1517
function ROCArray{T, N}(

src/memory_record.jl

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# NOTE: EXPERIMENTAL API.
2+
3+
# Registry of recorded arrays: a `Dict{UInt64, ROCArray}` wrapped in a `LockedObject`.
# Entries are inserted by `record!`, keyed by `_hash(x)`.
const MemoryRecords = LockedObject(Dict{UInt64, ROCArray}())
4+
5+
# Recording flag, initially `false`; set through `record_memory!` and read through
# `record_memory()`.
const RECORD_MEMORY::Ref{Bool} = Ref(false)
6+
7+
"""
    record_memory!(rec::Bool; free::Bool = true, sync::Bool = false)

Set the `RECORD_MEMORY` flag to `rec`.

When `rec` is `false` and `free` is `true`, additionally call
`free_records!(; sync)`. Returns `nothing`.
"""
function record_memory!(rec::Bool; free::Bool = true, sync::Bool = false)
    RECORD_MEMORY[] = rec

    # Recording was just switched on: nothing to release.
    rec && return
    # Recording was switched off, but the caller asked to keep the records.
    free || return

    free_records!(; sync = sync)
    return
end
14+
15+
"""
    record_memory() -> Bool

Return the current value of the `RECORD_MEMORY` flag.
"""
function record_memory()
    return RECORD_MEMORY[]
end
16+
17+
"""
    record!(x)

Store `x` in `MemoryRecords` under the key `_hash(x)`, replacing any entry
already stored under that key. Returns `nothing`.
"""
function record!(x)
    Base.lock(MemoryRecords) do registry
        key = _hash(x)
        registry[key] = x
    end
    return
end
21+
22+
"""
    free_records!(; sync::Bool = false)

Call `unsafe_free!` on every array stored in `MemoryRecords`, then empty the
registry. When `sync` is `true`, call `AMDGPU.synchronize()` afterwards.
Returns `nothing`.
"""
function free_records!(; sync::Bool = false)
    Base.lock(MemoryRecords) do registry
        # Keys are not needed here, only the recorded arrays.
        for arr in values(registry)
            unsafe_free!(arr)
        end
        empty!(registry)
    end
    if sync
        AMDGPU.synchronize()
    end
    return
end
32+
33+
"""
    remove_record!(x)

Remove the entry for `x` (keyed by `_hash(x)`) from `MemoryRecords`.

Does nothing when recording is disabled (`record_memory()` is `false`) or when
no entry exists for `x`. Returns `nothing`.
"""
function remove_record!(x)
    record_memory() || return

    k = _hash(x)
    Base.lock(MemoryRecords) do records
        # `delete!` is a no-op for a missing key. Do not test membership via
        # `records.keys`: that is `Dict`'s private slot array, which may contain
        # stale or uninitialized entries, so the check is a linear scan that can
        # report a key that is not live and make a following `pop!` throw `KeyError`.
        delete!(records, k)
    end
    return
end
44+
45+
# Key used for `MemoryRecords`: combines the array's `dims`, its `offset`, and the
# `ptr` reached through `x.buf.rc.obj.mem` into a single hash value.
function _hash(x::ROCArray)
    h = Base.hash(x.dims)
    h = Base.hash(x.offset, h)
    return Base.hash(x.buf.rc.obj.mem.ptr, h)
end

0 commit comments

Comments
 (0)