Skip to content

Commit 79f9d4e

Browse files
authored
Merge pull request #141 from exanauts/ms/active_kernels
Add support for active_kernels
2 parents 62f88d4 + 19be255 commit 79f9d4e

14 files changed

+139
-64
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AMDGPU"
22
uuid = "21141c5a-9bdb-4563-92ae-f87d6854732e"
33
authors = ["Julian P Samaroo <jpsamaroo@jpsamaroo.me>"]
4-
version = "0.2.11"
4+
version = "0.2.12"
55

66
[deps]
77
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

deps/loaddeps.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ try
1010
include("ext.jl")
1111
catch err
1212
if !isfile(joinpath(@__DIR__, "ext.jl"))
13-
@warn "Didn't find $ext, please build AMDGPU.jl"
13+
@warn "Didn't find deps/ext.jl, please build AMDGPU.jl"
1414
@eval const hsa_configured = false
1515
@eval const hip_configured = false
1616
@eval const device_libs_configured = false

docs/src/queues_signals.md

+6
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@ q = AMDGPU.HSAQueue(agent)
1212
@roc queue=q kernel(...)
1313
```
1414

15+
If you want to query which kernels are currently executing on a given queue,
16+
calling `AMDGPU.active_kernels(queue)` will return a `Vector{HSAStatusSignal}`,
17+
which can be inspected to determine how many (and which) kernels are executing
18+
by comparing them against the `event` field of the signals returned from `@roc`. You can also omit the `queue`
19+
argument, which will then check the default queue.
20+
1521
If a kernel ever gets "stuck" and locks up the GPU (noticeable with 100% GPU
1622
usage in `rocm-smi`), you can kill the kernel and all other kernels in the
1723
queue with `kill_queue!(queue)`. This can be "safely" done to the default

src/AMDGPU.jl

+4
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ import .HSA: Agent, Queue, Executable, Status, Signal
3131

3232
struct Adaptor end
3333

34+
const RT_LOCK = Threads.SpinLock()
35+
3436
include("extras.jl")
3537
include("error.jl")
3638
include("agent.jl")
@@ -48,6 +50,7 @@ if get(ENV, "AMDGPUNATIVE_OPENCL", "") != ""
4850
end
4951
=#
5052
include("runtime.jl")
53+
include("statussignal.jl")
5154
include("sync.jl")
5255

5356
# Device sources must load _before_ the compiler infrastructure
@@ -60,6 +63,7 @@ include(joinpath("device", "runtime.jl"))
6063
include(joinpath("device", "llvm.jl"))
6164
include(joinpath("device", "globals.jl"))
6265

66+
include("query.jl")
6367
include("compiler.jl")
6468
include("execution_utils.jl")
6569
include("execution.jl")

src/exceptions.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ struct KernelException <: Exception
66
dev::RuntimeDevice
77
exstr::Union{String,Nothing}
88
end
9-
KernelException(dev) = KernelException(dev, nothing)
9+
# Convenience constructor: a kernel exception on `dev` with no message string.
function KernelException(dev::RuntimeDevice)
    return KernelException(dev, nothing)
end
10+
# Convenience constructor: wrap a raw HSA agent in a `RuntimeDevice` first.
function KernelException(agent::HSAAgent, exstr=nothing)
    dev = RuntimeDevice(agent)
    return KernelException(dev, exstr)
end
1011

1112
function Base.showerror(io::IO, err::KernelException)
1213
print(io, "KernelException: exception(s) thrown during kernel execution on device $(err.dev.device)")

src/execution.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ macro roc(ex...)
192192
local $kernel = $rocfunction($f, $kernel_tt; $(compiler_kwargs...))
193193
foreach($wait!, ($(var_exprs...),))
194194
if $launch
195-
local $signal = $create_event($kernel.mod.exe)
195+
local $signal = $create_event($kernel.mod.exe; $(call_kwargs...))
196196
$kernel($kernel_args...; signal=$signal, $(call_kwargs...))
197197
foreach(x->$mark!(x, $signal), ($(var_exprs...),))
198198
$signal

src/execution_utils.jl

+6
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,12 @@ unpreserve!(ev::RuntimeEvent) = unpreserve!(ev.event)
204204
kern = create_kernel(get_device(queue), f.mod.exe, f.entry, args)
205205

206206
# launch kernel
207+
lock($RT_LOCK)
208+
try
209+
push!($_active_kernels[queue.queue], signal.event)
210+
finally
211+
unlock($RT_LOCK)
212+
end
207213
launch_kernel(queue, kern, signal;
208214
groupsize=groupsize, gridsize=gridsize)
209215

src/query.jl

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Per-queue registry of kernel status signals. Entries are pushed when a kernel
# is launched and removed by the signal's waiter once it has finished.
# All reads and writes of this dict are expected to happen while holding `RT_LOCK`.
const _active_kernels = IdDict{HSAQueue,Vector{AMDGPU.HSAStatusSignal}}()
2+
3+
"""
4+
active_kernels(queue) -> Vector{AMDGPU.HSAStatusSignal}
5+
6+
Returns the set of actively-executing kernels on `queue`.
7+
"""
8+
function active_kernels(queue::HSAQueue=get_default_queue())
    # Take a snapshot while holding the runtime lock, so the caller gets its
    # own vector rather than an alias of the live registry entry.
    lock(RT_LOCK)
    try
        return copy(_active_kernels[queue])
    finally
        unlock(RT_LOCK)
    end
end
13+
# Unwrap the runtime-level queue and query the queue it wraps.
function active_kernels(queue::RuntimeQueue)
    return active_kernels(queue.queue)
end

src/queue.jl

+7
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,18 @@ function HSAQueue(agent::HSAAgent)
2121
C_NULL, C_NULL, typemax(UInt32), typemax(UInt32),
2222
queue.queue) |> check
2323

24+
lock(RT_LOCK) do
25+
_active_kernels[queue] = Vector{AMDGPU.RuntimeEvent{AMDGPU.HSAStatusSignal}}()
26+
end
27+
2428
hsaref!()
2529
finalizer(queue) do queue
2630
if queue.active
2731
HSA.queue_destroy(queue.queue[]) |> check
2832
end
33+
lock(RT_LOCK) do
34+
delete!(_active_kernels, queue)
35+
end
2936
hsaunref!()
3037
end
3138
return queue

src/runtime.jl

+5-59
Original file line numberDiff line numberDiff line change
@@ -19,63 +19,8 @@ default_isa(device::RuntimeDevice{HSAAgent}) =
1919
"Thin wrapper around a runtime-specific event object of type `E`."
struct RuntimeEvent{E}
    event::E
end
22-
create_event(exe) = RuntimeEvent(create_event(RUNTIME[], exe))
23-
Base.wait(event::RuntimeEvent, exe) = wait(event.event, exe)
24-
25-
"Tracks the completion and status of a kernel's execution."
26-
struct HSAStatusSignal
27-
signal::HSASignal
28-
exe::HSAExecutable
29-
end
30-
create_event(::typeof(HSA_rt), exe) = HSAStatusSignal(HSASignal(), exe.exe)
31-
Base.wait(event::HSAStatusSignal; kwargs...) = wait(RuntimeEvent(event); kwargs...)
32-
function Base.wait(event::RuntimeEvent{HSAStatusSignal}; check_exceptions=true, cleanup=true, kwargs...)
33-
wait(event.event.signal; kwargs...) # wait for completion signal
34-
unpreserve!(event) # allow kernel-associated objects to be freed
35-
exe = event.event.exe::HSAExecutable{Mem.Buffer}
36-
mod = EXE_TO_MODULE_MAP[exe].value
37-
agent = exe.agent
38-
ex = nothing
39-
signal_handle = (event.event.signal.signal[]::HSA.Signal).handle
40-
if haskey(exe.globals, :__global_exception_flag)
41-
# Check if any wavefront for this kernel threw an exception
42-
ex_flag = get_global(exe, :__global_exception_flag)
43-
ex_flag_ptr = Base.unsafe_convert(Ptr{Int64}, ex_flag)
44-
ex_flag_value = Base.unsafe_load(ex_flag_ptr)
45-
if ex_flag_value != 0
46-
ex_strings = String[]
47-
if check_exceptions && haskey(exe.globals, :__global_exception_ring)
48-
# Check for and collect any exceptions, and clear their slots
49-
ex_ring = get_global(exe, :__global_exception_ring)
50-
ex_ring_ptr_ptr = Base.unsafe_convert(Ptr{Ptr{ExceptionEntry}}, ex_ring)
51-
ex_ring_ptr = unsafe_load(ex_ring_ptr_ptr)
52-
while (ex_ring_value = unsafe_load(ex_ring_ptr)).kern != 1
53-
if ex_ring_value.kern == signal_handle
54-
push!(ex_strings, unsafe_string(reinterpret(Ptr{UInt8}, ex_ring_value.ptr)))
55-
# FIXME: Write rest of entry first, then CAS 0 to kern field
56-
unsafe_store!(ex_ring_ptr, ExceptionEntry(UInt64(0), LLVMPtr{UInt8,1}(0)))
57-
end
58-
ex_ring_ptr += sizeof(ExceptionEntry)
59-
end
60-
end
61-
unique!(ex_strings)
62-
ex = KernelException(RuntimeDevice(agent), isempty(ex_strings) ? nothing : join(ex_strings, '\n'))
63-
end
64-
end
65-
if cleanup
66-
# Clean-up malloc'd data
67-
for idx in length(mod.metadata):-1:1
68-
metadata_value = mod.metadata[idx]
69-
if metadata_value.kern == signal_handle
70-
@debug "Cleaning up data: $metadata_value"
71-
Mem.free(metadata_value.buf)
72-
deleteat!(mod.metadata, idx)
73-
end
74-
end
75-
end
76-
ex !== nothing && throw(ex)
77-
end
78-
22+
# Create an event for `exe` using the currently selected runtime and wrap it
# in a `RuntimeEvent`; keyword arguments are passed through unchanged.
function create_event(exe; kwargs...)
    inner = create_event(RUNTIME[], exe; kwargs...)
    return RuntimeEvent(inner)
end
23+
# Waiting on the wrapper simply waits on the wrapped event.
function Base.wait(event::RuntimeEvent)
    return wait(event.event)
end
7924

8025
struct RuntimeExecutable{E}
8126
exe::E
@@ -100,6 +45,9 @@ end
10045
# Look up global `sym` on the executable wrapped by `exe`.
function get_global(exe::RuntimeExecutable, sym::Symbol)
    return get_global(exe.exe, sym)
end
10247

48+
# HSA backend: build the status signal that tracks one kernel launch of `exe`
# on `queue`. Any additional keyword arguments are accepted and ignored.
function create_event(::typeof(HSA_rt), exe::RuntimeExecutable{<:HSAExecutable};
                      queue=default_queue(), kwargs...)
    return HSAStatusSignal(HSASignal(), exe.exe, queue.queue)
end
50+
10351
"Thin wrapper around a runtime-specific kernel object of type `K`."
struct RuntimeKernel{K}
    kernel::K
end
@@ -116,6 +64,4 @@ function launch_kernel(::typeof(HSA_rt), queue, kern, event;
11664
workgroup_size=groupsize, grid_size=gridsize)
11765
end
11866
# Unwrap each `RuntimeEvent` and forward to the method that takes the wrapped events.
function barrier_and!(queue, events::Vector{<:RuntimeEvent})
    inner = map(ev -> ev.event, events)
    return barrier_and!(queue, inner)
end
119-
barrier_and!(queue, signals::Vector{HSAStatusSignal}) = barrier_and!(queue, map(x->x.signal,signals))
12067
# Unwrap each `RuntimeEvent` and forward to the method that takes the wrapped events.
function barrier_or!(queue, events::Vector{<:RuntimeEvent})
    inner = map(ev -> ev.event, events)
    return barrier_or!(queue, inner)
end
121-
barrier_or!(queue, signals::Vector{HSAStatusSignal}) = barrier_or!(queue, map(x->x.signal,signals))

src/statussignal.jl

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"Tracks the completion and status of a kernel's execution."
2+
mutable struct HSAStatusSignal
    signal::HSASignal                     # completion signal for the kernel
    exe::HSAExecutable                    # executable the kernel belongs to
    queue::HSAQueue                       # queue the kernel was launched on
    done::Base.Event                      # notified once `_wait` has finished
    exception::Union{Exception,Nothing}   # set by `_wait`; rethrown by `wait`

    function HSAStatusSignal(signal::HSASignal, exe::HSAExecutable, queue::HSAQueue; kwargs...)
        status = new(signal, exe, queue, Base.Event(), nothing)
        # The background task is the real waiter; `wait(status)` only blocks on `done`.
        @async _wait(status; kwargs...)
        return status
    end
end
14+
15+
"""
    wait(signal::HSAStatusSignal)

Block until `signal.done` has been notified, then rethrow the exception
recorded on the signal, if there is one.
"""
function Base.wait(signal::HSAStatusSignal)
    wait(signal.done)
    err = signal.exception
    err === nothing && return nothing
    throw(err)
end
end
22+
"""
    _wait(signal::HSAStatusSignal; check_exceptions=true, cleanup=true, kwargs...)

Background waiter for a launched kernel. Waits on the kernel's completion
signal (forwarding `kwargs` to that wait), collects any exception strings the
kernel recorded, frees kernel-associated metadata buffers when `cleanup` is
true, and stores the resulting `KernelException` (or `nothing`) in
`signal.exception`.

Whatever happens, the signal is removed from the active-kernel list of its
queue and `signal.done` is notified, so `wait(signal)` can never block forever.
If this function itself fails, the error is stored in `signal.exception` and is
rethrown by `wait(signal)`.
"""
function _wait(signal::HSAStatusSignal; check_exceptions=true, cleanup=true, kwargs...)
    try
        wait(signal.signal; kwargs...) # wait for completion signal
        unpreserve!(signal) # allow kernel-associated objects to be freed
        exe = signal.exe::HSAExecutable{Mem.Buffer}
        mod = EXE_TO_MODULE_MAP[exe].value
        agent = exe.agent
        ex = nothing
        signal_handle = (signal.signal.signal[]::HSA.Signal).handle
        if haskey(exe.globals, :__global_exception_flag)
            # Check if any wavefront for this kernel threw an exception
            ex_flag = get_global(exe, :__global_exception_flag)
            ex_flag_ptr = Base.unsafe_convert(Ptr{Int64}, ex_flag)
            ex_flag_value = Base.unsafe_load(ex_flag_ptr)
            if ex_flag_value != 0
                ex_strings = String[]
                if check_exceptions && haskey(exe.globals, :__global_exception_ring)
                    # Check for and collect any exceptions, and clear their slots
                    ex_ring = get_global(exe, :__global_exception_ring)
                    ex_ring_ptr_ptr = Base.unsafe_convert(Ptr{Ptr{ExceptionEntry}}, ex_ring)
                    ex_ring_ptr = unsafe_load(ex_ring_ptr_ptr)
                    # An entry whose `kern` field is 1 terminates the scan
                    while (ex_ring_value = unsafe_load(ex_ring_ptr)).kern != 1
                        if ex_ring_value.kern == signal_handle
                            push!(ex_strings, unsafe_string(reinterpret(Ptr{UInt8}, ex_ring_value.ptr)))
                            # FIXME: Write rest of entry first, then CAS 0 to kern field
                            unsafe_store!(ex_ring_ptr, ExceptionEntry(UInt64(0), LLVMPtr{UInt8,1}(0)))
                        end
                        ex_ring_ptr += sizeof(ExceptionEntry)
                    end
                end
                unique!(ex_strings)
                ex = KernelException(agent, isempty(ex_strings) ? nothing : join(ex_strings, '\n'))
            end
        end
        if cleanup
            # Clean-up malloc'd data; iterate backwards so deleteat! keeps indices valid
            for idx in length(mod.metadata):-1:1
                metadata_value = mod.metadata[idx]
                if metadata_value.kern == signal_handle
                    @debug "Cleaning up data: $metadata_value"
                    Mem.free(metadata_value.buf)
                    deleteat!(mod.metadata, idx)
                end
            end
        end
        signal.exception = ex
    catch err
        # This runs in a detached task, so an escaping error would be lost and
        # the waiter would hang. Record it for `wait(signal)` to rethrow.
        signal.exception = err isa Exception ? err : ErrorException(string(err))
    finally
        lock(RT_LOCK) do
            # Tolerate a missing registry entry so that `notify` below is always reached
            active = get(_active_kernels, signal.queue, nothing)
            if active !== nothing
                filter!(x -> x != signal, active)
            end
        end
        notify(signal.done)
    end
    return nothing
end
end
72+
# Barrier helpers for status signals: forward the underlying `signal` fields
# to the methods that take those directly.
function barrier_and!(queue, signals::Vector{HSAStatusSignal})
    raw = map(s -> s.signal, signals)
    return barrier_and!(queue, raw)
end
function barrier_or!(queue, signals::Vector{HSAStatusSignal})
    raw = map(s -> s.signal, signals)
    return barrier_or!(queue, raw)
end

test/device/exceptions.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ catch err
1212
@test err isa AMDGPU.KernelException
1313
if err isa AMDGPU.KernelException
1414
@test err.exstr !== nothing
15-
@test occursin("BoundsError", err.exstr)
15+
@test occursin("julia_throw_boundserror", err.exstr)
1616
end
1717
end
1818

test/device/queries.jl

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
@testset "Active kernels" begin
    # Kernel that performs a single hostcall and then returns
    function kernel(call)
        hostcall!(call)
        nothing
    end

    # The hostcall handler blocks on the host until `release` is notified,
    # which keeps the kernel in flight while the first checks run.
    release = Base.Event()
    hc = HostCall(Nothing, Tuple{}) do
        wait(release)
    end

    sig = @roc kernel(hc)
    status = sig.event
    @test status in AMDGPU.active_kernels()
    @test !status.done.set

    # Let the kernel finish, then check that it has been deregistered
    notify(release)
    wait(sig)
    @test !(status in AMDGPU.active_kernels())
end

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ if AMDGPU.configured
5555
include("device/execution_control.jl")
5656
include("device/exceptions.jl")
5757
include("device/deps.jl")
58+
include("device/queries.jl")
5859
end
5960
@testset "ROCArray" begin
6061
@testset "GPUArrays test suite" begin

0 commit comments

Comments
 (0)