Skip to content

Commit f836632

Browse files
authored
Merge pull request #116 from JuliaGPU/jps/mark-wait
Add mark/wait synchronization system
2 parents 91e61e1 + 4432499 commit f836632

19 files changed

+219
-47
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AMDGPU"
22
uuid = "21141c5a-9bdb-4563-92ae-f87d6854732e"
33
authors = ["Julian P Samaroo <jpsamaroo@jpsamaroo.me>"]
4-
version = "0.2.5"
4+
version = "0.2.6"
55

66
[deps]
77
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

src/AMDGPU.jl

+1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ if get(ENV, "AMDGPUNATIVE_OPENCL", "") != ""
4848
end
4949
=#
5050
include("runtime.jl")
51+
include("sync.jl")
5152

5253
# Device sources must load _before_ the compiler infrastructure
5354
# because of generated functions.

src/array.jl

+22-2
Original file line numberDiff line numberDiff line change
@@ -53,17 +53,18 @@ end
5353
# Host abstractions
5454
#
5555

56-
5756
mutable struct ROCArray{T,N} <: AbstractGPUArray{T,N}
5857
buf::Mem.Buffer
5958
own::Bool
6059

6160
dims::Dims{N}
6261
offset::Int
6362

63+
syncstate::SyncState
64+
6465
function ROCArray{T,N}(buf::Mem.Buffer, dims::Dims{N}; offset::Integer=0, own::Bool=true) where {T,N}
6566
@assert isbitstype(T) "ROCArray only supports bits types"
66-
xs = new{T,N}(buf, own, dims, offset)
67+
xs = new{T,N}(buf, own, dims, offset, SyncState())
6768
if own
6869
hsaref!()
6970
Mem.retain(buf)
@@ -79,6 +80,21 @@ function unsafe_free!(xs::ROCArray)
7980
return
8081
end
8182

83+
"""
    wait!(x::ROCArray)

Forward to `wait!` on the sync state held in `x.syncstate` and return its result.
"""
function wait!(x::ROCArray)
    return wait!(x.syncstate)
end
84+
"""
    mark!(x::ROCArray, s)

Forward to `mark!` on the sync state held in `x.syncstate`, passing `s` through
unchanged, and return its result.
"""
function mark!(x::ROCArray, s)
    return mark!(x.syncstate, s)
end
85+
"""
    wait!(xs::Vector{<:ROCArray})

Call `wait!` on every array in `xs`, in order. Returns `nothing`.
"""
function wait!(xs::Vector{<:ROCArray})
    for arr in xs
        wait!(arr)
    end
    return nothing
end
86+
"""
    mark!(xs::Vector{<:ROCArray}, s)

Call `mark!(x, s)` for every array `x` in `xs`, in order. Returns `nothing`.
"""
function mark!(xs::Vector{<:ROCArray}, s)
    for arr in xs
        mark!(arr, s)
    end
    return nothing
end
87+
"""
    wait!(xs::NTuple{N,<:ROCArray} where N)

Call `wait!` on every array in the tuple `xs`, in order. Returns `nothing`.
"""
function wait!(xs::NTuple{N,<:ROCArray} where N)
    return foreach(xs) do arr
        wait!(arr)
    end
end
88+
"""
    mark!(xs::NTuple{N,<:ROCArray} where N, s)

Call `mark!(x, s)` for every array `x` in the tuple `xs`, in order. Returns `nothing`.
"""
function mark!(xs::NTuple{N,<:ROCArray} where N, s)
    return foreach(xs) do arr
        mark!(arr, s)
    end
end
89+
# Adapt hook for `WaitAdaptor`: calls `wait!` on the array's sync state and then
# returns the same array `x` (no conversion or copy is performed here).
# NOTE(review): `WaitAdaptor` is defined outside this view (presumably sync.jl) — confirm.
function Adapt.adapt_storage(::WaitAdaptor, x::ROCArray)
90+
wait!(x.syncstate)
91+
x
92+
end
93+
# Adapt hook for `MarkAdaptor`: calls `mark!` on the array's sync state with the
# adaptor's `s` field and then returns the same array `x` (no conversion or copy).
# NOTE(review): the meaning of `ma.s` is not visible here — confirm against the
# `MarkAdaptor` definition.
function Adapt.adapt_storage(ma::MarkAdaptor, x::ROCArray)
94+
mark!(x.syncstate, ma.s)
95+
x
96+
end
97+
8298
## aliases
8399

84100
const ROCVector{T} = ROCArray{T,1}
@@ -154,6 +170,7 @@ function Base.copyto!(dest::Array{T}, d_offset::Integer,
154170
amount == 0 && return dest
155171
@boundscheck checkbounds(dest, d_offset+amount-1)
156172
@boundscheck checkbounds(source, s_offset+amount-1)
173+
wait!(source)
157174
Mem.download!(pointer(dest, d_offset),
158175
Mem.view(source.buf, (s_offset-1)*sizeof(T)),
159176
amount*sizeof(T))
@@ -165,6 +182,7 @@ function Base.copyto!(dest::ROCArray{T}, d_offset::Integer,
165182
amount == 0 && return dest
166183
@boundscheck checkbounds(dest, d_offset+amount-1)
167184
@boundscheck checkbounds(source, s_offset+amount-1)
185+
wait!(dest)
168186
Mem.upload!(Mem.view(dest.buf, (d_offset-1)*sizeof(T)),
169187
pointer(source, s_offset),
170188
amount*sizeof(T))
@@ -176,6 +194,8 @@ function Base.copyto!(dest::ROCArray{T}, d_offset::Integer,
176194
amount == 0 && return dest
177195
@boundscheck checkbounds(dest, d_offset+amount-1)
178196
@boundscheck checkbounds(source, s_offset+amount-1)
197+
wait!(dest)
198+
wait!(source)
179199
Mem.transfer!(Mem.view(dest.buf, (d_offset-1)*sizeof(T)),
180200
Mem.view(source.buf, (s_offset-1)*sizeof(T)),
181201
amount*sizeof(T))

src/blas/rocBLAS.jl

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module rocBLAS
22

33
using ..AMDGPU
4+
import AMDGPU: wait!, mark!
45

56
using LinearAlgebra
67

0 commit comments

Comments
 (0)