@@ -19,63 +19,8 @@ default_isa(device::RuntimeDevice{HSAAgent}) =
19
19
struct RuntimeEvent{E}
20
20
event:: E
21
21
end
22
- create_event (exe) = RuntimeEvent (create_event (RUNTIME[], exe))
23
- Base. wait (event:: RuntimeEvent , exe) = wait (event. event, exe)
24
-
25
- " Tracks the completion and status of a kernel's execution."
26
- struct HSAStatusSignal
27
- signal:: HSASignal
28
- exe:: HSAExecutable
29
- end
30
- create_event (:: typeof (HSA_rt), exe) = HSAStatusSignal (HSASignal (), exe. exe)
31
- Base. wait (event:: HSAStatusSignal ; kwargs... ) = wait (RuntimeEvent (event); kwargs... )
32
- function Base. wait (event:: RuntimeEvent{HSAStatusSignal} ; check_exceptions= true , cleanup= true , kwargs... )
33
- wait (event. event. signal; kwargs... ) # wait for completion signal
34
- unpreserve! (event) # allow kernel-associated objects to be freed
35
- exe = event. event. exe:: HSAExecutable{Mem.Buffer}
36
- mod = EXE_TO_MODULE_MAP[exe]. value
37
- agent = exe. agent
38
- ex = nothing
39
- signal_handle = (event. event. signal. signal[]:: HSA.Signal ). handle
40
- if haskey (exe. globals, :__global_exception_flag )
41
- # Check if any wavefront for this kernel threw an exception
42
- ex_flag = get_global (exe, :__global_exception_flag )
43
- ex_flag_ptr = Base. unsafe_convert (Ptr{Int64}, ex_flag)
44
- ex_flag_value = Base. unsafe_load (ex_flag_ptr)
45
- if ex_flag_value != 0
46
- ex_strings = String[]
47
- if check_exceptions && haskey (exe. globals, :__global_exception_ring )
48
- # Check for and collect any exceptions, and clear their slots
49
- ex_ring = get_global (exe, :__global_exception_ring )
50
- ex_ring_ptr_ptr = Base. unsafe_convert (Ptr{Ptr{ExceptionEntry}}, ex_ring)
51
- ex_ring_ptr = unsafe_load (ex_ring_ptr_ptr)
52
- while (ex_ring_value = unsafe_load (ex_ring_ptr)). kern != 1
53
- if ex_ring_value. kern == signal_handle
54
- push! (ex_strings, unsafe_string (reinterpret (Ptr{UInt8}, ex_ring_value. ptr)))
55
- # FIXME : Write rest of entry first, then CAS 0 to kern field
56
- unsafe_store! (ex_ring_ptr, ExceptionEntry (UInt64 (0 ), LLVMPtr {UInt8,1} (0 )))
57
- end
58
- ex_ring_ptr += sizeof (ExceptionEntry)
59
- end
60
- end
61
- unique! (ex_strings)
62
- ex = KernelException (RuntimeDevice (agent), isempty (ex_strings) ? nothing : join (ex_strings, ' \n ' ))
63
- end
64
- end
65
- if cleanup
66
- # Clean-up malloc'd data
67
- for idx in length (mod. metadata): - 1 : 1
68
- metadata_value = mod. metadata[idx]
69
- if metadata_value. kern == signal_handle
70
- @debug " Cleaning up data: $metadata_value "
71
- Mem. free (metadata_value. buf)
72
- deleteat! (mod. metadata, idx)
73
- end
74
- end
75
- end
76
- ex != = nothing && throw (ex)
77
- end
78
-
22
+ create_event (exe; kwargs... ) = RuntimeEvent (create_event (RUNTIME[], exe; kwargs... ))
23
+ Base. wait (event:: RuntimeEvent ) = wait (event. event)
79
24
80
25
struct RuntimeExecutable{E}
81
26
exe:: E
100
45
get_global (exe:: RuntimeExecutable , sym:: Symbol ) =
101
46
get_global (exe. exe, sym)
102
47
48
+ create_event (:: typeof (HSA_rt), exe:: RuntimeExecutable{<:HSAExecutable} ; queue= default_queue (), kwargs... ) =
49
+ HSAStatusSignal (HSASignal (), exe. exe, queue. queue)
50
+
103
51
struct RuntimeKernel{K}
104
52
kernel:: K
105
53
end
@@ -116,6 +64,4 @@ function launch_kernel(::typeof(HSA_rt), queue, kern, event;
116
64
workgroup_size= groupsize, grid_size= gridsize)
117
65
end
118
66
barrier_and! (queue, events:: Vector{<:RuntimeEvent} ) = barrier_and! (queue, map (x-> x. event,events))
119
- barrier_and! (queue, signals:: Vector{HSAStatusSignal} ) = barrier_and! (queue, map (x-> x. signal,signals))
120
67
barrier_or! (queue, events:: Vector{<:RuntimeEvent} ) = barrier_or! (queue, map (x-> x. event,events))
121
- barrier_or! (queue, signals:: Vector{HSAStatusSignal} ) = barrier_or! (queue, map (x-> x. signal,signals))
0 commit comments