Skip to content

Commit 8d577ab

Browse files
authored
zeros/ones/fill may accept arbitrary axes that are supported by similar (JuliaLang#53965)
The idea is that functions like `zeros` are essentially constructing a container and filling it with a value. `similar` seems perfectly placed to construct such a container, so we may accept arbitrary axes in `zeros` as long as there's a corresponding `similar` method that is defined for the axes. Packages therefore would only need to define `similar`, and would get `zeros`/`ones` and `fill` for free. For example, the following will work after this: ```julia julia> using StaticArrays julia> zeros(SOneTo(2), 2) 2×2 Matrix{Float64}: 0.0 0.0 0.0 0.0 julia> zeros(SOneTo(2), Base.OneTo(2)) 2×2 Matrix{Float64}: 0.0 0.0 0.0 0.0 ``` Neither of these work on the current master, as `StaticArrays` doesn't define `zeros` for these combinations, even though it does define `similar`. One may argue for these methods to be added to `StaticArrays`, but this seems to be adding redundancy. The flip side is that `OffsetArrays` defines exactly these methods, so adding them to `Base` would break precompilation for the package. However, `OffsetArrays` really shouldn't be defining these methods, as this is type-piracy. The methods may be version-limited in `OffsetArrays` if this PR is merged. On the face of it, `trues` and `falses` should also work similarly, but currently these seem to be bypassing `similar` and constructing a `BitArray` explicitly. I have not added the corresponding methods for these functions, but they may be added as well.
1 parent b9aeafa commit 8d577ab

File tree

5 files changed

+67
-0
lines changed

5 files changed

+67
-0
lines changed

base/array.jl

+6
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,7 @@ function fill end
529529
fill(v, dims::DimOrInd...) = fill(v, dims)
530530
fill(v, dims::NTuple{N, Union{Integer, OneTo}}) where {N} = fill(v, map(to_dim, dims))
531531
fill(v, dims::NTuple{N, Integer}) where {N} = (a=Array{typeof(v),N}(undef, dims); fill!(a, v); a)
532+
fill(v, dims::NTuple{N, DimOrInd}) where {N} = (a=similar(Array{typeof(v),N}, dims); fill!(a, v); a)
532533
fill(v, dims::Tuple{}) = (a=Array{typeof(v),0}(undef, dims); fill!(a, v); a)
533534

534535
"""
@@ -589,6 +590,11 @@ for (fname, felt) in ((:zeros, :zero), (:ones, :one))
589590
fill!(a, $felt(T))
590591
return a
591592
end
593+
function $fname(::Type{T}, dims::NTuple{N, DimOrInd}) where {T,N}
594+
a = similar(Array{T,N}, dims)
595+
fill!(a, $felt(T))
596+
return a
597+
end
592598
end
593599
end
594600

base/bitarray.jl

+2
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,7 @@ falses(dims::DimOrInd...) = falses(dims)
404404
falses(dims::NTuple{N, Union{Integer, OneTo}}) where {N} = falses(map(to_dim, dims))
405405
falses(dims::NTuple{N, Integer}) where {N} = fill!(BitArray(undef, dims), false)
406406
falses(dims::Tuple{}) = fill!(BitArray(undef, dims), false)
407+
falses(dims::NTuple{N, DimOrInd}) where {N} = fill!(similar(BitArray, dims), false)
407408

408409
"""
409410
trues(dims)
@@ -422,6 +423,7 @@ trues(dims::DimOrInd...) = trues(dims)
422423
trues(dims::NTuple{N, Union{Integer, OneTo}}) where {N} = trues(map(to_dim, dims))
423424
trues(dims::NTuple{N, Integer}) where {N} = fill!(BitArray(undef, dims), true)
424425
trues(dims::Tuple{}) = fill!(BitArray(undef, dims), true)
426+
trues(dims::NTuple{N, DimOrInd}) where {N} = fill!(similar(BitArray, dims), true)
425427

426428
function one(x::BitMatrix)
427429
m, n = size(x)

test/abstractarray.jl

+22
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ using .Main.StructArrays
1111
isdefined(Main, :FillArrays) || @eval Main include("testhelpers/FillArrays.jl")
1212
using .Main.FillArrays
1313

14+
isdefined(Main, :SizedArrays) || @eval Main include("testhelpers/SizedArrays.jl")
15+
using .Main.SizedArrays
16+
1417
A = rand(5,4,3)
1518
@testset "Bounds checking" begin
1619
@test checkbounds(Bool, A, 1, 1, 1) == true
@@ -2097,3 +2100,22 @@ end
20972100
@test r2[i] == z[j]
20982101
end
20992102
end
2103+
2104+
@testset "zero for arbitrary axes" begin
2105+
r = SizedArrays.SOneTo(2)
2106+
s = Base.OneTo(2)
2107+
_to_oneto(x::Integer) = Base.OneTo(2)
2108+
_to_oneto(x::Union{Base.OneTo, SizedArrays.SOneTo}) = x
2109+
for (f, v) in ((zeros, 0), (ones, 1), ((x...)->fill(3,x...),3))
2110+
for ax in ((r,r), (s, r), (2, r))
2111+
A = f(ax...)
2112+
@test axes(A) == map(_to_oneto, ax)
2113+
if all(x -> x isa SizedArrays.SOneTo, ax)
2114+
@test A isa SizedArrays.SizedArray && parent(A) isa Array
2115+
else
2116+
@test A isa Array
2117+
end
2118+
@test all(==(v), A)
2119+
end
2120+
end
2121+
end

test/bitarray.jl

+22
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
using Base: findprevnot, findnextnot
44
using Random, LinearAlgebra, Test
55

6+
isdefined(Main, :SizedArrays) || @eval Main include("testhelpers/SizedArrays.jl")
7+
using .Main.SizedArrays
8+
69
tc(r1::NTuple{N,Any}, r2::NTuple{N,Any}) where {N} = all(x->tc(x...), [zip(r1,r2)...])
710
tc(r1::BitArray{N}, r2::Union{BitArray{N},Array{Bool,N}}) where {N} = true
811
tc(r1::SubArray{Bool,N1,BitArray{N2}}, r2::SubArray{Bool,N1,<:Union{BitArray{N2},Array{Bool,N2}}}) where {N1,N2} = true
@@ -82,6 +85,25 @@ allsizes = [((), BitArray{0}), ((v1,), BitVector),
8285
@test !isassigned(b, length(b) + 1)
8386
end
8487

88+
@testset "trues and falses with custom axes" begin
89+
for ax in ((SizedArrays.SOneTo(2),), (SizedArrays.SOneTo(2), Base.OneTo(2)))
90+
t = trues(ax)
91+
if all(x -> x isa SizedArrays.SOneTo, ax)
92+
@test t isa SizedArrays.SizedArray && parent(t) isa BitArray
93+
else
94+
@test t isa BitArray
95+
end
96+
@test all(t)
97+
98+
f = falses(ax)
99+
if all(x -> x isa SizedArrays.SOneTo, ax)
100+
@test t isa SizedArrays.SizedArray && parent(t) isa BitArray
101+
else
102+
@test t isa BitArray
103+
end
104+
@test !any(f)
105+
end
106+
end
85107

86108
@testset "Conversions for size $sz" for (sz, T) in allsizes
87109
b1 = rand!(falses(sz...))

test/testhelpers/SizedArrays.jl

+15
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,25 @@ Base.size(a::SizedArray) = size(typeof(a))
4343
Base.size(::Type{<:SizedArray{SZ}}) where {SZ} = SZ
4444
Base.axes(a::SizedArray) = map(SOneTo, size(a))
4545
Base.getindex(A::SizedArray, i...) = getindex(A.data, i...)
46+
Base.setindex!(A::SizedArray, v, i...) = setindex!(A.data, v, i...)
4647
Base.zero(::Type{T}) where T <: SizedArray = SizedArray{size(T)}(zeros(eltype(T), size(T)))
48+
Base.parent(S::SizedArray) = S.data
4749
+(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = SizedArray{SZ}(S1.data + S2.data)
4850
==(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = S1.data == S2.data
4951

52+
homogenize_shape(t::Tuple) = (_homogenize_shape(first(t)), homogenize_shape(Base.tail(t))...)
53+
homogenize_shape(::Tuple{}) = ()
54+
_homogenize_shape(x::Integer) = x
55+
_homogenize_shape(x::AbstractUnitRange) = length(x)
56+
const Dims = Union{Integer, Base.OneTo, SOneTo}
57+
function Base.similar(::Type{A}, shape::Tuple{Dims, Vararg{Dims}}) where {A<:AbstractArray}
58+
similar(A, homogenize_shape(shape))
59+
end
60+
function Base.similar(::Type{A}, shape::Tuple{SOneTo, Vararg{SOneTo}}) where {A<:AbstractArray}
61+
R = similar(A, length.(shape))
62+
SizedArray{length.(shape)}(R)
63+
end
64+
5065
const SizedMatrixLike = Union{SizedMatrix, Transpose{<:Any, <:SizedMatrix}, Adjoint{<:Any, <:SizedMatrix}}
5166

5267
_data(S::SizedArray) = S.data

0 commit comments

Comments
 (0)