Skip to content

Commit 8c937e3

Browse files
authored
Fix cache retrieval (#718)
1 parent b6f3376 commit 8c937e3

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

Project.toml

+1 -1
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
name = "AMDGPU"
22
uuid = "21141c5a-9bdb-4563-92ae-f87d6854732e"
33
authors = ["Julian P Samaroo <jpsamaroo@jpsamaroo.me>", "Valentin Churavy <v.churavy@gmail.com>", "Anton Smirnov <tonysmn97@gmail.com>"]
4-
version = "1.2.0"
4+
version = "1.2.1"
55

66
[deps]
77
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

src/array.jl

+6 -1
Original file line number | Diff line number | Diff line change
@@ -6,10 +6,15 @@ mutable struct ROCArray{T, N, B} <: AbstractGPUArray{T, N}
66
function ROCArray{T, N, B}(::UndefInitializer, dims::Dims{N}) where {T, N, B <: Mem.AbstractAMDBuffer}
77
@assert isbitstype(T) "ROCArray only supports bits types"
88
sz::Int64 = prod(dims) * sizeof(T)
9-
return GPUArrays.cached_alloc((ROCArray, AMDGPU.device(), T, B, sz)) do
9+
x = GPUArrays.cached_alloc((ROCArray, AMDGPU.device(), T, B, sz)) do
1010
@debug "Allocate `T=$T`, `dims=$dims`: $(Base.format_bytes(sz))"
1111
data = DataRef(pool_free, pool_alloc(B, sz))
1212
return finalizer(unsafe_free!, new{T, N, B}(data, dims, 0))
13+
end
14+
return if size(x) != dims
15+
reshape(x, dims)
16+
else
17+
x
1318
end::ROCArray{T, N, B}
1419
end
1520

0 commit comments

Comments
 (0)