3
3
# Out-of-place softmax for `ROCArray` inputs: allocates the output with
# `similar(x)` and forwards to the in-place `softmax!` with the same `dims`.
softmax(x::T; dims) where T <: ROCArray = softmax!(similar(x), x; dims)
4
4
5
5
# In-place softmax for `ROCArray` inputs: writes the result for `x` into `y`
# by delegating to `_softmax!` with the `MIOPEN_SOFTMAX_ACCURATE` algorithm
# (the post-change side of the diff; the previous revision passed
# `MIOPEN_SOFTMAX_FAST`).
# NOTE(review): presumably ACCURATE is the max-subtracted, numerically safer
# MIOpen variant — confirm against the MIOpen documentation.
softmax!(y::T, x::T; dims) where T <: ROCArray =
    _softmax!(MIOPEN_SOFTMAX_ACCURATE, y, x; dims)
7
7
8
8
# Out-of-place softmax gradient for `ROCArray` inputs: allocates the result
# with `similar(y)` and forwards `dy`, `y` and `dims` to the in-place
# `∇softmax!`.
function ∇softmax(dy::T, y::T; dims) where T <: ROCArray
    return ∇softmax!(similar(y), dy, y; dims)
end
11
11
12
12
# In-place softmax gradient for `ROCArray` inputs: delegates to `_∇softmax!`
# with the `MIOPEN_SOFTMAX_ACCURATE` algorithm, passing `dx`, `dy`, `y` and
# `dims` through unchanged (the previous revision passed
# `MIOPEN_SOFTMAX_FAST`).
# NOTE(review): the algorithm here should match the one used by `softmax!`
# so forward and backward passes agree — confirm.
function ∇softmax!(dx::T, dy::T, y::T; dims) where T <: ROCArray
    return _∇softmax!(MIOPEN_SOFTMAX_ACCURATE, dx, dy, y; dims)
end
15
15
16
16
# Log-softmax.
@@ -53,8 +53,8 @@ function _softmax!(
53
53
) where T <: ROCArray
54
54
sdims = _softmax_dims (x; dims)
55
55
if isnothing (sdims)
56
- return (algo == MIOPEN_SOFTMAX_FAST ) ?
57
- _softmax ! (y, x; dims) : _logsoftmax ! (y, x; dims)
56
+ return (algo == MIOPEN_SOFTMAX_LOG ) ?
57
+ _logsoftmax ! (y, x; dims) : _softmax ! (y, x; dims)
58
58
end
59
59
60
60
AMDGPU. wait! ((x, y))
@@ -71,8 +71,8 @@ function _∇softmax!(
71
71
) where T <: ROCArray
72
72
sdims = _softmax_dims (y; dims)
73
73
if isnothing (sdims)
74
- return (algo == MIOPEN_SOFTMAX_FAST ) ?
75
- _∇softmax ! (dx, dy, y; dims) : _∇logsoftmax ! (dx, dy, y; dims)
74
+ return (algo == MIOPEN_SOFTMAX_LOG ) ?
75
+ _∇logsoftmax ! (dx, dy, y; dims) : _∇softmax ! (dx, dy, y; dims)
76
76
end
77
77
78
78
AMDGPU. wait! ((dx, dy, y))
0 commit comments