Skip to content

Commit 0602de8

Browse files
authored
Update to latest alloc cache (#723)
1 parent 0659e31 commit 0602de8

File tree

4 files changed: +10 -9 lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ CEnum = "0.4, 0.5"
4949
ChainRulesCore = "1"
5050
EnzymeCore = "0.8"
5151
ExprTools = "0.1"
52-
GPUArrays = "11.2"
52+
GPUArrays = "11.2.1"
5353
GPUCompiler = "0.27, 1.0"
5454
KernelAbstractions = "0.9.2"
5555
LLD_jll = "15, 16, 17"

src/array.jl

+3-8
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,11 @@ mutable struct ROCArray{T, N, B} <: AbstractGPUArray{T, N}
66
function ROCArray{T, N, B}(::UndefInitializer, dims::Dims{N}) where {T, N, B <: Mem.AbstractAMDBuffer}
77
@assert isbitstype(T) "ROCArray only supports bits types"
88
sz::Int64 = prod(dims) * sizeof(T)
9-
x = GPUArrays.cached_alloc((ROCArray, AMDGPU.device(), T, B, sz)) do
9+
ref = GPUArrays.cached_alloc((ROCArray, AMDGPU.device(), B, sz)) do
1010
@debug "Allocate `T=$T`, `dims=$dims`: $(Base.format_bytes(sz))"
11-
data = DataRef(pool_free, pool_alloc(B, sz))
12-
return finalizer(unsafe_free!, new{T, N, B}(data, dims, 0))
11+
DataRef(pool_free, pool_alloc(B, sz))
1312
end
14-
return if size(x) != dims
15-
reshape(x, dims)
16-
else
17-
x
18-
end::ROCArray{T, N, B}
13+
return finalizer(unsafe_free!, new{T, N, B}(ref, dims, 0))
1914
end
2015

2116
function ROCArray{T, N}(buf::DataRef{Managed{B}}, dims::Dims{N}; offset::Integer = 0) where {T, N, B <: Mem.AbstractAMDBuffer}

src/memory.jl

+2
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,8 @@ function synchronize(m::Managed)
217217
return
218218
end
219219

220+
"""
    sizeof(m::Managed)

Return the size of the memory object wrapped by `m`, i.e. whatever
`sizeof(m.mem)` reports for the underlying buffer.
"""
function Base.sizeof(m::Managed)
    return sizeof(m.mem)
end
221+
220222
function Base.convert(::Type{Ptr{T}}, managed::Managed{M}) where {T, M}
221223
strm = AMDGPU.stream()
222224

src/runtime/memory/hip.jl

+4
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ function HIPBuffer(ptr::Ptr{Cvoid}, bytesize::Int)
6969
HIPBuffer(s.device, s.ctx, ptr, bytesize, false)
7070
end
7171

72+
"""
    sizeof(b::HIPBuffer)

Return the `bytesize` recorded in `b`, converted to `UInt64`.
"""
function Base.sizeof(b::HIPBuffer)
    return UInt64(b.bytesize)
end
73+
7274
Base.convert(::Type{Ptr{T}}, buf::HIPBuffer) where T = convert(Ptr{T}, buf.ptr)
7375

7476
function view(buf::HIPBuffer, bytesize::Int)
@@ -137,6 +139,8 @@ function HostBuffer(
137139
HostBuffer(stream.device, stream.ctx, ptr, dev_ptr, sz, false)
138140
end
139141

142+
"""
    sizeof(b::HostBuffer)

Return the `bytesize` recorded in `b`, converted to `UInt64`.
"""
function Base.sizeof(b::HostBuffer)
    return UInt64(b.bytesize)
end
143+
140144
function view(buf::HostBuffer, bytesize::Int)
141145
bytesize > buf.bytesize && throw(BoundsError(buf, bytesize))
142146
HostBuffer(

0 commit comments

Comments
 (0)