Skip to content

Commit a52cbbe

Browse files
authored
Add findnz & fix broadcasting (#704)
1 parent 57d8324 commit a52cbbe

File tree

5 files changed

+63
-2
lines changed

5 files changed

+63
-2
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AMDGPU"
22
uuid = "21141c5a-9bdb-4563-92ae-f87d6854732e"
33
authors = ["Julian P Samaroo <jpsamaroo@jpsamaroo.me>", "Valentin Churavy <v.churavy@gmail.com>", "Anton Smirnov <tonysmn97@gmail.com>"]
4-
version = "1.1.2"
4+
version = "1.1.3"
55

66
[deps]
77
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

src/broadcast.jl

+12-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,18 @@ BroadcastStyle(::Type{<:ROCArray{T, N, B}}) where {T, N, B} =
99
BroadcastStyle(W::Type{<:AnyROCArray{T, N}}) where {T, N} =
1010
ROCArrayStyle{N, buftype(Adapt.unwrap_type(W))}()
1111

12-
# TODO handle broadcast of different buffer types (use unified memory).
12+
# TODO use unified buffer once we support it.
13+
# Broadcast of two different buffers - choose `HIPBuffer`.
14+
BroadcastStyle(
15+
::ROCArrayStyle{N1, B1},
16+
::ROCArrayStyle{N2, B2},
17+
) where {N1,N2,B1,B2} = ROCArrayStyle{max(N1,N2), Mem.HIPBuffer}()
18+
19+
# Different N, same buffer type.
20+
BroadcastStyle(
21+
::ROCArrayStyle{N1, B},
22+
::ROCArrayStyle{N2, B},
23+
) where {N1,N2,B} = ROCArrayStyle{max(N1,N2), B}()
1324

1425
# Allocation of output arrays.
1526
function Base.similar(

src/sparse/array.jl

+19
Original file line numberDiff line numberDiff line change
@@ -262,12 +262,31 @@ Base.eltype(g::ROCSparseMatrix{T}) where T = T
262262

263263
## sparse array interface
264264

265+
SparseArrays.sparsevec(I::ROCArray{Ti}, V::ROCArray{Tv}, n::Integer) where {Ti,Tv} =
266+
ROCSparseVector(I, V, n)
267+
268+
function SparseArrays.findnz(S::T) where {T <: AbstractROCSparseMatrix}
269+
S2 = ROCSparseMatrixCOO(S)
270+
I = S2.rowInd
271+
J = S2.colInd
272+
V = S2.nzVal
273+
274+
# To make it compatible with the SparseArrays.jl version
275+
idxs = sortperm(J)
276+
I = I[idxs]
277+
J = J[idxs]
278+
V = V[idxs]
279+
280+
return (I, J, V)
281+
end
282+
265283
SparseArrays.nnz(g::AbstractROCSparseArray) = g.nnz
266284
SparseArrays.nonzeros(g::AbstractROCSparseArray) = g.nzVal
267285

268286
SparseArrays.nonzeroinds(g::AbstractROCSparseVector) = g.iPtr
269287

270288
SparseArrays.rowvals(g::ROCSparseMatrixCSC) = g.rowVal
289+
SparseArrays.getcolptr(g::ROCSparseMatrixCSC) = g.colPtr
271290

272291
LinearAlgebra.issymmetric(M::Union{ROCSparseMatrixCSC,ROCSparseMatrixCSR}) = false
273292
LinearAlgebra.ishermitian(M::Union{ROCSparseMatrixCSC,ROCSparseMatrixCSR}) = false

test/rocarray/base.jl

+8
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,14 @@ end
145145
AMDGPU.unsafe_free!(xd2)
146146
@test_throws ArgumentError pointer(xd2)
147147
end
148+
149+
@testset "Broadcasting different buffer types" begin
150+
x = rand(Float32, 4, 16, 16)
151+
xd = unsafe_wrap(ROCArray, pointer(x), size(x))
152+
y = AMDGPU.zeros(Float32, 3, 16, 16)
153+
y .= @view(xd[1:3, :, :])
154+
@test Array(y) @view(x[1:3, :, :])
155+
end
148156
end
149157

150158
@testset "unsafe_free" begin

test/rocsparse/interfaces.jl

+23
Original file line numberDiff line numberDiff line change
@@ -203,3 +203,26 @@
203203
@test (dD - dA) isa typ
204204
end
205205
end
206+
207+
@testset "SparseArrays.jl" begin
208+
@testset "findnz" begin
209+
n = 35
210+
A = sprand(n, n, 0.2)
211+
d_A = ROCSparseMatrixCSC(A)
212+
@test Array(SparseArrays.getcolptr(d_A)) == SparseArrays.getcolptr(A)
213+
214+
i, j, v = findnz(A)
215+
d_i, d_j, d_v = findnz(d_A)
216+
@test Array(d_i) == i && Array(d_j) == j && Array(d_v) == v
217+
218+
i = unique(sort(rand(1:n, 10)))
219+
vals = rand(length(i))
220+
d_i = ROCArray(i)
221+
d_vals = ROCArray(vals)
222+
v = sparsevec(i, vals, n)
223+
d_v = sparsevec(d_i, d_vals, n)
224+
@test Array(d_v.iPtr) == v.nzind
225+
@test Array(d_v.nzVal) == v.nzval
226+
@test d_v.len == v.n
227+
end
228+
end

0 commit comments

Comments
 (0)