Skip to content

Commit f62a380

Browse files
authored
Specialize indexing triangular matrices with BandIndex (JuliaLang#55644)
With this, certain indexing operations involving a `BandIndex` may be evaluated as constants. This isn't used directly presently, but might allow for more performant broadcasting in the future. With this, ```julia julia> n = 3; T = Tridiagonal(rand(n-1), rand(n), rand(n-1)); julia> @code_warntype ((T,j) -> UpperTriangular(T)[LinearAlgebra.BandIndex(2,j)])(T, 1) MethodInstance for (::var"mmtk#17#18")(::Tridiagonal{Float64, Vector{Float64}}, ::Int64) from (::var"mmtk#17#18")(T, j) @ Main REPL[12]:1 Arguments #self#::Core.Const(var"mmtk#17#18"()) T::Tridiagonal{Float64, Vector{Float64}} j::Int64 Body::Float64 1 ─ %1 = Main.UpperTriangular(T)::UpperTriangular{Float64, Tridiagonal{Float64, Vector{Float64}}} │ %2 = LinearAlgebra.BandIndex::Core.Const(LinearAlgebra.BandIndex) │ %3 = (%2)(2, j)::Core.PartialStruct(LinearAlgebra.BandIndex, Any[Core.Const(2), Int64]) │ %4 = Base.getindex(%1, %3)::Core.Const(0.0) └── return %4 ``` The indexing operation may be evaluated at compile-time, as the band index is constant-propagated.
1 parent 9136bdd commit f62a380

File tree

4 files changed

+55
-4
lines changed

4 files changed

+55
-4
lines changed

stdlib/LinearAlgebra/src/bidiag.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -166,10 +166,11 @@ end
166166
end
167167

168168
@inline function getindex(A::Bidiagonal{T}, b::BandIndex) where T
169-
@boundscheck checkbounds(A, _cartinds(b))
169+
@boundscheck checkbounds(A, b)
170170
if b.band == 0
171171
return @inbounds A.dv[b.index]
172-
elseif b.band == _offdiagind(A.uplo)
172+
elseif b.band (-1,1) && b.band == _offdiagind(A.uplo)
173+
# we explicitly compare the possible bands as b.band may be constant-propagated
173174
return @inbounds A.ev[b.index]
174175
else
175176
return bidiagzero(A, Tuple(_cartinds(b))...)

stdlib/LinearAlgebra/src/dense.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ norm2(x::Union{Array{T},StridedVector{T}}) where {T<:BlasFloat} =
110110
# Conservative assessment of types that have zero(T) defined for themselves
111111
haszero(::Type) = false
112112
haszero(::Type{T}) where {T<:Number} = isconcretetype(T)
113-
@propagate_inbounds _zero(M::AbstractArray{T}, i, j) where {T} = haszero(T) ? zero(T) : zero(M[i,j])
113+
@propagate_inbounds _zero(M::AbstractArray{T}, inds...) where {T} = haszero(T) ? zero(T) : zero(M[inds...])
114114

115115
"""
116116
triu!(M, k::Integer)

stdlib/LinearAlgebra/src/triangular.jl

+14
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,20 @@ Base.isstored(A::UpperTriangular, i::Int, j::Int) =
236236
@propagate_inbounds getindex(A::UpperTriangular, i::Int, j::Int) =
237237
i <= j ? A.data[i,j] : _zero(A.data,j,i)
238238

239+
# these specialized getindex methods enable constant-propagation of the band
240+
Base.@constprop :aggressive @propagate_inbounds function getindex(A::UnitLowerTriangular{T}, b::BandIndex) where {T}
241+
b.band < 0 ? A.data[b] : ifelse(b.band == 0, oneunit(T), zero(T))
242+
end
243+
Base.@constprop :aggressive @propagate_inbounds function getindex(A::LowerTriangular, b::BandIndex)
244+
b.band <= 0 ? A.data[b] : _zero(A.data, b)
245+
end
246+
Base.@constprop :aggressive @propagate_inbounds function getindex(A::UnitUpperTriangular{T}, b::BandIndex) where {T}
247+
b.band > 0 ? A.data[b] : ifelse(b.band == 0, oneunit(T), zero(T))
248+
end
249+
Base.@constprop :aggressive @propagate_inbounds function getindex(A::UpperTriangular, b::BandIndex)
250+
b.band >= 0 ? A.data[b] : _zero(A.data, b)
251+
end
252+
239253
_zero_triangular_half_str(::Type{<:UpperOrUnitUpperTriangular}) = "lower"
240254
_zero_triangular_half_str(::Type{<:LowerOrUnitLowerTriangular}) = "upper"
241255

stdlib/LinearAlgebra/test/triangular.jl

+37-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ debug = false
66
using Test, LinearAlgebra, Random
77
using LinearAlgebra: BlasFloat, errorbounds, full!, transpose!,
88
UnitUpperTriangular, UnitLowerTriangular,
9-
mul!, rdiv!, rmul!, lmul!
9+
mul!, rdiv!, rmul!, lmul!, BandIndex
1010

1111
const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test")
1212

@@ -1286,4 +1286,40 @@ end
12861286
end
12871287
end
12881288

1289+
@testset "indexing with a BandIndex" begin
1290+
# these tests should succeed even if the linear index along
1291+
# the band isn't a constant, or type-inferred at all
1292+
M = rand(Int,2,2)
1293+
f(A,j, v::Val{n}) where {n} = Val(A[BandIndex(n,j)])
1294+
function common_tests(M, ind)
1295+
j = ind[]
1296+
@test @inferred(f(UpperTriangular(M), j, Val(-1))) == Val(0)
1297+
@test @inferred(f(UnitUpperTriangular(M), j, Val(-1))) == Val(0)
1298+
@test @inferred(f(UnitUpperTriangular(M), j, Val(0))) == Val(1)
1299+
@test @inferred(f(LowerTriangular(M), j, Val(1))) == Val(0)
1300+
@test @inferred(f(UnitLowerTriangular(M), j, Val(1))) == Val(0)
1301+
@test @inferred(f(UnitLowerTriangular(M), j, Val(0))) == Val(1)
1302+
end
1303+
common_tests(M, Any[1])
1304+
1305+
M = Diagonal([1,2])
1306+
common_tests(M, Any[1])
1307+
# extra tests for banded structure of the parent
1308+
for T in (UpperTriangular, UnitUpperTriangular)
1309+
@test @inferred(f(T(M), 1, Val(1))) == Val(0)
1310+
end
1311+
for T in (LowerTriangular, UnitLowerTriangular)
1312+
@test @inferred(f(T(M), 1, Val(-1))) == Val(0)
1313+
end
1314+
1315+
M = Tridiagonal([1,2], [1,2,3], [1,2])
1316+
common_tests(M, Any[1])
1317+
for T in (UpperTriangular, UnitUpperTriangular)
1318+
@test @inferred(f(T(M), 1, Val(2))) == Val(0)
1319+
end
1320+
for T in (LowerTriangular, UnitLowerTriangular)
1321+
@test @inferred(f(T(M), 1, Val(-2))) == Val(0)
1322+
end
1323+
end
1324+
12891325
end # module TestTriangular

0 commit comments

Comments
 (0)