Skip to content

Commit 9e13ee6

Browse files
authored
Use accurate softmax (#414)
1 parent 36bfc53 commit 9e13ee6

File tree

2 files changed

+8
-9
lines changed

2 files changed

+8
-9
lines changed

src/array.jl

+2 -3
Original file line number | Diff line number | Diff line change
@@ -8,9 +8,8 @@ struct ROCArrayBackend <: AbstractGPUBackend end
88

99
struct ROCKernelContext <: AbstractKernelContext end
1010

11-
function GPUArrays.gpu_call(::ROCArrayBackend, f, args, threads::Int, blocks::Int;
12-
name::Union{String,Nothing})
13-
groupsize, gridsize = threads, blocks*threads
11+
function GPUArrays.gpu_call(::ROCArrayBackend, f, args, threads::Int, blocks::Int; name::Union{String,Nothing})
12+
groupsize, gridsize = threads, blocks * threads
1413
wait(@roc groupsize=groupsize gridsize=gridsize f(ROCKernelContext(), args...))
1514
end
1615
function GPUArrays.gpu_call(::ROCArrayBackend, f, args; elements::Int, name::Union{String,Nothing}=nothing)

src/dnn/softmax.jl

+6 -6
Original file line number | Diff line number | Diff line change
@@ -3,14 +3,14 @@
33
softmax(x::T; dims) where T <: ROCArray = softmax!(similar(x), x; dims)
44

55
softmax!(y::T, x::T; dims) where T <: ROCArray =
6-
_softmax!(MIOPEN_SOFTMAX_FAST, y, x; dims)
6+
_softmax!(MIOPEN_SOFTMAX_ACCURATE, y, x; dims)
77

88
function ∇softmax(dy::T, y::T; dims) where T <: ROCArray
99
∇softmax!(similar(y), dy, y; dims)
1010
end
1111

1212
function ∇softmax!(dx::T, dy::T, y::T; dims) where T <: ROCArray
13-
_∇softmax!(MIOPEN_SOFTMAX_FAST, dx, dy, y; dims)
13+
_∇softmax!(MIOPEN_SOFTMAX_ACCURATE, dx, dy, y; dims)
1414
end
1515

1616
# Log-softmax.
@@ -53,8 +53,8 @@ function _softmax!(
5353
) where T <: ROCArray
5454
sdims = _softmax_dims(x; dims)
5555
if isnothing(sdims)
56-
return (algo == MIOPEN_SOFTMAX_FAST) ?
57-
_softmax!(y, x; dims) : _logsoftmax!(y, x; dims)
56+
return (algo == MIOPEN_SOFTMAX_LOG) ?
57+
_logsoftmax!(y, x; dims) : _softmax!(y, x; dims)
5858
end
5959

6060
AMDGPU.wait!((x, y))
@@ -71,8 +71,8 @@ function _∇softmax!(
7171
) where T <: ROCArray
7272
sdims = _softmax_dims(y; dims)
7373
if isnothing(sdims)
74-
return (algo == MIOPEN_SOFTMAX_FAST) ?
75-
_∇softmax!(dx, dy, y; dims) : _∇logsoftmax!(dx, dy, y; dims)
74+
return (algo == MIOPEN_SOFTMAX_LOG) ?
75+
_∇logsoftmax!(dx, dy, y; dims) : _∇softmax!(dx, dy, y; dims)
7676
end
7777

7878
AMDGPU.wait!((dx, dy, y))

0 commit comments

Comments
 (0)