Skip to content

Commit 1c1ec4d

Browse files
pxl-thwsmoses
andauthored
Initial Enzyme support (#668)
--------- Co-authored-by: William Moses <gh@wsmoses.com>
1 parent 0fece1f commit 1c1ec4d

File tree

8 files changed

+402
-14
lines changed

8 files changed

+402
-14
lines changed

.buildkite/pipeline.yml

+19
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,25 @@ steps:
8383
# JULIA_AMDGPU_HIP_MUST_LOAD: "1"
8484
# JULIA_AMDGPU_DISABLE_ARTIFACTS: "1"
8585

86+
- label: "Julia 1.10 Enzyme"
87+
plugins:
88+
- JuliaCI/julia#v1:
89+
version: "1.10"
90+
- JuliaCI/julia-test#v1:
91+
test_args: "enzyme"
92+
agents:
93+
queue: "juliagpu"
94+
rocm: "*"
95+
rocmgpu: "*"
96+
if: build.message !~ /\[skip tests\]/
97+
command: "julia --project -e 'using Pkg; Pkg.update()'"
98+
timeout_in_minutes: 180
99+
env:
100+
JULIA_NUM_THREADS: 4
101+
JULIA_AMDGPU_CORE_MUST_LOAD: "1"
102+
JULIA_AMDGPU_HIP_MUST_LOAD: "1"
103+
JULIA_AMDGPU_DISABLE_ARTIFACTS: "1"
104+
86105
- label: "GPU-less environment"
87106
plugins:
88107
- JuliaCI/julia#v1:

Project.toml

+4-1
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,20 @@ UnsafeAtomics = "013be700-e6cd-48c3-b4a1-df204f14c38f"
3434

3535
[weakdeps]
3636
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
37+
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
3738

3839
[extensions]
3940
AMDGPUChainRulesCoreExt = "ChainRulesCore"
41+
AMDGPUEnzymeCoreExt = "EnzymeCore"
4042

4143
[compat]
4244
AbstractFFTs = "1.0"
4345
AcceleratedKernels = "0.2"
4446
Adapt = "4"
45-
Atomix = "0.1, 1"
47+
Atomix = "1"
4648
CEnum = "0.4, 0.5"
4749
ChainRulesCore = "1"
50+
EnzymeCore = "0.8"
4851
ExprTools = "0.1"
4952
GPUArrays = "11.2"
5053
GPUCompiler = "0.27, 1.0"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
module AMDGPUEnzymeCoreExt
2+
3+
using AMDGPU
4+
using EnzymeCore
5+
using EnzymeCore: EnzymeRules
6+
using GPUCompiler
7+
8+
include("meta_kernels.jl")
9+
10+
function EnzymeCore.compiler_job_from_backend(
11+
::ROCBackend, @nospecialize(F::Type), @nospecialize(TT::Type),
12+
)
13+
mi = GPUCompiler.methodinstance(F, TT)
14+
return GPUCompiler.CompilerJob(mi, AMDGPU.compiler_config(AMDGPU.device()))
15+
end
16+
17+
function EnzymeRules.forward(
18+
config, fn::Const{typeof(AMDGPU.hipfunction)}, ::Type{<: Duplicated},
19+
f::Const{F}, tt::Const{TT}; kwargs...,
20+
) where {F, TT}
21+
res = fn.val(f.val, tt.val; kwargs...)
22+
return Duplicated(res, res)
23+
end
24+
25+
function EnzymeRules.forward(
26+
config, fn::Const{typeof(AMDGPU.hipfunction)}, ::Type{<: BatchDuplicated{T, N}},
27+
f::Const{F}, tt::Const{TT}; kwargs...,
28+
) where {F, TT, T, N}
29+
res = fn.val(f.val, tt.val; kwargs...)
30+
return BatchDuplicated(res, ntuple(_ -> res, Val(N)))
31+
end
32+
33+
function EnzymeRules.reverse(
34+
config, fn::Const{typeof(AMDGPU.hipfunction)}, ::Type{RT},
35+
subtape, f, tt; kwargs...,
36+
) where RT
37+
return (nothing, nothing)
38+
end
39+
40+
function EnzymeRules.forward(
41+
config, fn::Const{typeof(AMDGPU.rocconvert)}, ::Type{RT}, x::IT,
42+
) where {RT, IT}
43+
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
44+
config_width = EnzymeRules.width(config)
45+
if config_width == 1
46+
Duplicated(fn.val(x.val), fn.val(x.dval))
47+
else
48+
tup = ntuple(Val(config_width)) do i
49+
Base.@_inline_meta
50+
fn.val(x.dval[i])::eltype(RT)
51+
end
52+
BatchDuplicated(fn.val(x.val), tup)
53+
end
54+
55+
elseif EnzymeRules.needs_shadow(config)
56+
config_width = EnzymeRules.width(config)
57+
ST = EnzymeCore.shadow_type(config, RT)
58+
if config_width == 1
59+
fn.val(x.dval)::ST
60+
else
61+
(ntuple(Val(config_width)) do i
62+
Base.@_inline_meta
63+
fn.val(x.dval[i])::eltype(RT)
64+
end)::ST
65+
end
66+
67+
elseif EnzymeRules.needs_primal(config)
68+
fn.val(x.val)::eltype(RT)
69+
else
70+
nothing
71+
end
72+
end
73+
74+
function EnzymeRules.augmented_primal(
75+
config, fn::Const{typeof(AMDGPU.rocconvert)}, ::Type{RT}, x::IT,
76+
) where {RT, IT}
77+
primal = EnzymeRules.needs_primal(config) ?
78+
fn.val(x.val) : nothing
79+
80+
shadow = if EnzymeRules.needs_shadow(config)
81+
config_width = EnzymeRules.width(config)
82+
if config_width == 1
83+
fn.val(x.dval)
84+
else
85+
ntuple(Val(config_width)) do i
86+
Base.@_inline_meta
87+
fn.val(x.dval[i])
88+
end
89+
end
90+
else
91+
nothing
92+
end
93+
94+
return EnzymeRules.AugmentedReturn{
95+
EnzymeRules.primal_type(config, RT),
96+
EnzymeRules.shadow_type(config, RT), Nothing
97+
}(primal, shadow, nothing)
98+
end
99+
100+
function EnzymeRules.reverse(
101+
config, fn::Const{typeof(AMDGPU.rocconvert)}, ::Type{RT}, tape, x::IT,
102+
) where {RT, IT}
103+
return (nothing,)
104+
end
105+
106+
function EnzymeRules.forward(
107+
config, fn::EnzymeCore.Annotation{AMDGPU.Runtime.HIPKernel{F, TT}},
108+
::Type{Const{Nothing}}, args...; kwargs...,
109+
) where {F, TT}
110+
GC.@preserve args begin
111+
kernel_args = ((rocconvert(a) for a in args)...,)
112+
kernel_tt = Tuple{(typeof(config), F, (typeof(a) for a in kernel_args)...)...}
113+
kernel = AMDGPU.hipfunction(meta_fn, kernel_tt)
114+
kernel(config, fn.val.f, kernel_args...; kwargs...)
115+
end
116+
return
117+
end
118+
119+
function EnzymeRules.reverse(
120+
config, ofn::EnzymeCore.Annotation{AMDGPU.Runtime.HIPKernel{F, TT}},
121+
::Type{Const{Nothing}}, subtape, args...;
122+
groupsize::AMDGPU.Runtime.ROCDim = 1,
123+
gridsize::AMDGPU.Runtime.ROCDim = 1,
124+
kwargs...,
125+
) where {F, TT}
126+
kernel_args = ((rocconvert(a) for a in args)...,)
127+
kernel_tt = map(typeof, kernel_args)
128+
129+
ModifiedBetween = EnzymeRules.overwritten(config)
130+
TapeType = EnzymeCore.tape_type(
131+
ReverseSplitModified(
132+
EnzymeCore.set_runtime_activity(ReverseSplitWithPrimal, config),
133+
Val(ModifiedBetween)),
134+
Const{F},
135+
Const{Nothing},
136+
kernel_tt...,
137+
)
138+
groupsize = AMDGPU.Runtime.ROCDim3(groupsize)
139+
gridsize = AMDGPU.Runtime.ROCDim3(gridsize)
140+
141+
GC.@preserve args subtape begin
142+
subtape_cc = rocconvert(subtape)
143+
kernel_tt2 = Tuple{
144+
(typeof(config), F, typeof(subtape_cc), kernel_tt...)...}
145+
kernel = AMDGPU.hipfunction(meta_revf, kernel_tt2)
146+
kernel(config, ofn.val.f, subtape_cc, kernel_args...;
147+
groupsize, gridsize, kwargs...)
148+
end
149+
150+
return ntuple(Val(length(kernel_args))) do i
151+
Base.@_inline_meta
152+
nothing
153+
end
154+
end
155+
156+
function EnzymeRules.augmented_primal(
157+
config, fn::Const{typeof(AMDGPU.hipfunction)},
158+
::Type{RT}, f::Const{F}, tt::Const{TT}; kwargs...
159+
) where {F, CT, RT <: EnzymeCore.Annotation{CT}, TT}
160+
res = fn.val(f.val, tt.val; kwargs...)
161+
primal = EnzymeRules.needs_primal(config) ? res : nothing
162+
163+
shadow = if EnzymeRules.needs_shadow(config)
164+
config_width = EnzymeRules.width(config)
165+
config_width == 1 ?
166+
res :
167+
ntuple(Val(config_width)) do i
168+
Base.@_inline_meta
169+
res
170+
end
171+
else
172+
nothing
173+
end
174+
175+
return EnzymeRules.AugmentedReturn{
176+
EnzymeRules.primal_type(config, RT),
177+
EnzymeRules.shadow_type(config, RT), Nothing,
178+
}(primal, shadow, nothing)
179+
end
180+
181+
function EnzymeRules.augmented_primal(
182+
config, fn::EnzymeCore.Annotation{AMDGPU.Runtime.HIPKernel{F,TT}},
183+
::Type{Const{Nothing}}, args...;
184+
groupsize::AMDGPU.Runtime.ROCDim = 1,
185+
gridsize::AMDGPU.Runtime.ROCDim = 1, kwargs...,
186+
) where {F,TT}
187+
kernel_args = ((rocconvert(a) for a in args)...,)
188+
kernel_tt = map(typeof, kernel_args)
189+
190+
ModifiedBetween = EnzymeRules.overwritten(config)
191+
compiler_job = EnzymeCore.compiler_job_from_backend(
192+
ROCBackend(), typeof(Base.identity), Tuple{Float64})
193+
TapeType = EnzymeCore.tape_type(
194+
compiler_job,
195+
ReverseSplitModified(
196+
EnzymeCore.set_runtime_activity(ReverseSplitWithPrimal, config),
197+
Val(ModifiedBetween)),
198+
Const{F}, Const{Nothing},
199+
kernel_tt...,
200+
)
201+
groupsize = AMDGPU.Runtime.ROCDim3(groupsize)
202+
gridsize = AMDGPU.Runtime.ROCDim3(gridsize)
203+
subtape = ROCArray{TapeType}(undef,
204+
gridsize.x * gridsize.y * gridsize.z *
205+
groupsize.x * groupsize.y * groupsize.z)
206+
207+
GC.@preserve args subtape begin
208+
subtape_cc = rocconvert(subtape)
209+
kernel_tt2 = Tuple{
210+
(typeof(config), F, typeof(subtape_cc), kernel_tt...)...}
211+
kernel = AMDGPU.hipfunction(meta_augf, kernel_tt2)
212+
kernel(config, fn.val.f, subtape_cc, kernel_args...;
213+
groupsize, gridsize, kwargs...)
214+
end
215+
return EnzymeRules.AugmentedReturn{Nothing, Nothing, ROCArray}(nothing, nothing, subtape)
216+
end
217+
218+
end
+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
function meta_fn(config, fn, args::Vararg{Any, N}) where N
2+
EnzymeCore.autodiff_deferred(
3+
EnzymeCore.set_runtime_activity(Forward, config),
4+
Const(fn), Const, args...)
5+
return
6+
end
7+
8+
9+
function meta_augf(
10+
config, f, tape::ROCDeviceArray{TapeType}, args::Vararg{Any, N},
11+
) where {N, TapeType}
12+
ModifiedBetween = EnzymeRules.overwritten(config)
13+
forward, _ = EnzymeCore.autodiff_deferred_thunk(
14+
ReverseSplitModified(
15+
EnzymeCore.set_runtime_activity(ReverseSplitWithPrimal, config),
16+
Val(ModifiedBetween)),
17+
TapeType,
18+
Const{Core.Typeof(f)},
19+
Const{Nothing},
20+
map(typeof, args)...,
21+
)
22+
23+
idx = 0
24+
# idx *= gridDim().x
25+
idx += workgroupIdx().x - 1
26+
27+
idx *= gridGroupDim().y
28+
idx += workgroupIdx().y - 1
29+
30+
idx *= gridGroupDim().z
31+
idx += workgroupIdx().z - 1
32+
33+
idx *= workgroupDim().x
34+
idx += workitemIdx().x - 1
35+
36+
idx *= workgroupDim().y
37+
idx += workitemIdx().y - 1
38+
39+
idx *= workgroupDim().z
40+
idx += workitemIdx().z - 1
41+
idx += 1
42+
43+
@inbounds tape[idx] = forward(Const(f), args...)[1]
44+
return
45+
end
46+
47+
function meta_revf(
48+
config, f, tape::ROCDeviceArray{TapeType}, args::Vararg{Any, N},
49+
) where {N, TapeType}
50+
ModifiedBetween = EnzymeRules.overwritten(config)
51+
_, reverse = EnzymeCore.autodiff_deferred_thunk(
52+
ReverseSplitModified(
53+
EnzymeCore.set_runtime_activity(ReverseSplitWithPrimal, config),
54+
Val(ModifiedBetween)),
55+
TapeType,
56+
Const{Core.Typeof(f)},
57+
Const{Nothing},
58+
map(typeof, args)...,
59+
)
60+
61+
idx = 0
62+
# idx *= gridDim().x
63+
idx += workgroupIdx().x - 1
64+
65+
idx *= gridGroupDim().y
66+
idx += workgroupIdx().y - 1
67+
68+
idx *= gridGroupDim().z
69+
idx += workgroupIdx().z - 1
70+
71+
idx *= workgroupDim().x
72+
idx += workitemIdx().x - 1
73+
74+
idx *= workgroupDim().y
75+
idx += workitemIdx().y - 1
76+
77+
idx *= workgroupDim().z
78+
idx += workitemIdx().z - 1
79+
idx += 1
80+
81+
reverse(Const(f), args..., @inbounds tape[idx])
82+
return
83+
end

src/AMDGPU.jl

+9-9
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ using .ROCmDiscovery
6565

6666
include("utils.jl")
6767

68-
include(joinpath("hsa", "HSA.jl"))
69-
include(joinpath("hip", "HIP.jl"))
68+
include("hsa/HSA.jl")
69+
include("hip/HIP.jl")
7070

7171
using .HIP
7272
using .HIP: HIPContext, HIPDevice, HIPStream
@@ -101,7 +101,7 @@ export sync_workgroup, sync_workgroup_count, sync_workgroup_and, sync_workgroup_
101101

102102
include("compiler/Compiler.jl")
103103
import .Compiler
104-
import .Compiler: hipfunction
104+
import .Compiler: hipfunction, compiler_config
105105

106106
include("tls.jl")
107107
include("highlevel.jl")
@@ -117,12 +117,12 @@ include("kernels/accumulate.jl")
117117
include("kernels/sorting.jl")
118118
include("kernels/reverse.jl")
119119

120-
include(joinpath("blas", "rocBLAS.jl"))
121-
include(joinpath("solver", "rocSOLVER.jl"))
122-
include(joinpath("sparse", "rocSPARSE.jl"))
123-
include(joinpath("rand", "rocRAND.jl"))
124-
include(joinpath("fft", "rocFFT.jl"))
125-
include(joinpath("dnn", "MIOpen.jl"))
120+
include("blas/rocBLAS.jl")
121+
include("solver/rocSOLVER.jl")
122+
include("sparse/rocSPARSE.jl")
123+
include("rand/rocRAND.jl")
124+
include("fft/rocFFT.jl")
125+
include("dnn/MIOpen.jl")
126126

127127
include("random.jl")
128128

0 commit comments

Comments
 (0)