Skip to content

Commit e0656ac

Browse files
authored
Add sort for NTuples (JuliaLang#54494)
This is partially a reland of JuliaLang#46104, but without the controversial `sort(x) = sort!(copymutable(x))` and with some extensibility improvements. Implements JuliaLang#54489.
1 parent 54755ad commit e0656ac

File tree

3 files changed

+71
-3
lines changed

3 files changed

+71
-3
lines changed

NEWS.md

+1
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ New library features
110110
* `invoke` now supports passing a Method instead of a type signature making this interface somewhat more flexible for certain uncommon use cases ([#56692]).
111111
* `invoke` now supports passing a CodeInstance instead of a type, which can enable
112112
certain compiler plugin workflows ([#56660]).
113+
* `sort` now supports `NTuple`s ([#54494])
113114

114115
Standard library changes
115116
------------------------

base/sort.jl

+40-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ module Sort
44

55
using Base.Order
66

7-
using Base: copymutable, midpoint, require_one_based_indexing, uinttype,
7+
using Base: copymutable, midpoint, require_one_based_indexing, uinttype, tail,
88
sub_with_overflow, add_with_overflow, OneTo, BitSigned, BitIntegerType, top_set_bit
99

1010
import Base:
@@ -1482,8 +1482,9 @@ InitialOptimizations(next) = SubArrayOptimization(
14821482
`DefaultStable` is an algorithm which indicates that a fast, general purpose sorting
14831483
algorithm should be used, but does not specify exactly which algorithm.
14841484
1485-
Currently, it is composed of two parts: the [`InitialOptimizations`](@ref) and a hybrid of
1486-
Radix, Insertion, Counting, Quick sorts.
1485+
Currently, when sorting short NTuples, this is an unrolled mergesort, and otherwise it is
1486+
composed of two parts: the [`InitialOptimizations`](@ref) and a hybrid of Radix, Insertion,
1487+
Counting, Quick sorts.
14871488
14881489
We begin with MissingOptimization because it has no runtime cost when it is not
14891490
triggered and can enable other optimizations to be applied later. For example,
@@ -1619,6 +1620,7 @@ defalg(v::AbstractArray) = DEFAULT_STABLE
16191620
defalg(v::AbstractArray{<:Union{Number, Missing}}) = DEFAULT_UNSTABLE
16201621
defalg(v::AbstractArray{Missing}) = DEFAULT_UNSTABLE # for method disambiguation
16211622
defalg(v::AbstractArray{Union{}}) = DEFAULT_UNSTABLE # for method disambiguation
1623+
defalg(v::NTuple) = DEFAULT_STABLE
16221624

16231625
"""
16241626
sort!(v; alg::Base.Sort.Algorithm=Base.Sort.defalg(v), lt=isless, by=identity, rev::Bool=false, order::Base.Order.Ordering=Base.Order.Forward)
@@ -1757,6 +1759,41 @@ julia> v
17571759
"""
17581760
sort(v::AbstractVector; kws...) = sort!(copymutable(v); kws...)
17591761

1762+
function sort(x::NTuple;
1763+
alg::Algorithm=defalg(x),
1764+
lt=isless,
1765+
by=identity,
1766+
rev::Union{Bool,Nothing}=nothing,
1767+
order::Ordering=Forward,
1768+
scratch::Union{Vector, Nothing}=nothing)
1769+
# Can't do this check with type parameters because of https://github.com/JuliaLang/julia/issues/56698
1770+
scratch === nothing || eltype(x) == eltype(scratch) || throw(ArgumentError("scratch has the wrong eltype"))
1771+
_sort(x, alg, ord(lt,by,rev,order), (;scratch))::typeof(x)
1772+
end
1773+
# Folks who want to hack internals can define a new _sort(x::NTuple, ::TheirAlg, o::Ordering)
1774+
# or _sort(x::NTuple{N, TheirType}, ::DefaultStable, o::Ordering) where N
1775+
function _sort(x::NTuple, a::Union{DefaultStable, DefaultUnstable}, o::Ordering, kw)
1776+
# The unrolled tuple sort is prohibitively slow to compile for length > 9.
1777+
# See https://github.com/JuliaLang/julia/pull/46104#issuecomment-1435688502 for benchmarks
1778+
if length(x) > 9
1779+
v = copymutable(x)
1780+
_sort!(v, a, o, kw)
1781+
typeof(x)(v)
1782+
else
1783+
_mergesort(x, o)
1784+
end
1785+
end
1786+
_mergesort(x::Union{NTuple{0}, NTuple{1}}, o::Ordering) = x
1787+
function _mergesort(x::NTuple, o::Ordering)
1788+
a, b = Base.IteratorsMD.split(x, Val(length(x)>>1))
1789+
merge(_mergesort(a, o), _mergesort(b, o), o)
1790+
end
1791+
merge(x::NTuple, y::NTuple{0}, o::Ordering) = x
1792+
merge(x::NTuple{0}, y::NTuple, o::Ordering) = y
1793+
merge(x::NTuple{0}, y::NTuple{0}, o::Ordering) = x # Method ambiguity
1794+
merge(x::NTuple, y::NTuple, o::Ordering) =
1795+
(lt(o, y[1], x[1]) ? (y[1], merge(x, tail(y), o)...) : (x[1], merge(tail(x), y, o)...))
1796+
17601797
## partialsortperm: the permutation to sort the first k elements of an array ##
17611798

17621799
"""

test/sorting.jl

+30
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,22 @@ end
9494
vcat(2000, (x:x+99 for x in 1900:-100:100)..., 1:99)
9595
end
9696

97+
function tuple_sort_test(x)
98+
@test issorted(sort(x))
99+
length(x) > 9 && return # length > 9 uses a vector fallback
100+
@test 0 == @allocated sort(x)
101+
end
102+
@testset "sort(::NTuple)" begin
103+
@test sort(()) == ()
104+
@test sort((9,8,3,3,6,2,0,8)) == (0,2,3,3,6,8,8,9)
105+
@test sort((9,8,3,3,6,2,0,8), by=x->x÷3) == (2,0,3,3,8,6,8,9)
106+
for i in 1:40
107+
tuple_sort_test(rand(NTuple{i, Float64}))
108+
end
109+
@test_throws MethodError sort((1,2,3.0))
110+
@test Base.infer_return_type(sort, Tuple{Tuple{Vararg{Int}}}) == Tuple{Vararg{Int}}
111+
end
112+
97113
@testset "partialsort" begin
98114
@test partialsort([3,6,30,1,9],3) == 6
99115
@test partialsort([3,6,30,1,9],3:4) == [6,9]
@@ -913,6 +929,20 @@ end
913929
end
914930
@test sort([1,2,3], alg=MySecondAlg()) == [9,9,9]
915931
@test all(sort(v, alg=Base.Sort.InitialOptimizations(MySecondAlg())) .=== vcat(fill(9, 100), fill(missing, 10)))
932+
933+
# Tuple extensions (custom alg)
934+
@test_throws MethodError sort((1,2,3), alg=MyFirstAlg())
935+
Base.Sort._sort(v::NTuple, ::MyFirstAlg, o::Base.Order.Ordering, kw) = (17,2,9)
936+
@test sort((1,2,3), alg=MyFirstAlg()) == (17,2,9)
937+
938+
struct TupleFoo
939+
x::Int
940+
end
941+
942+
# Tuple extensions (custom type)
943+
@test_throws MethodError sort(TupleFoo.((3,1,2)))
944+
Base.Sort._sort(v::NTuple{N, TupleFoo}, ::Base.Sort.DefaultStable, o::Base.Order.Ordering, kw) where N = v
945+
@test sort(TupleFoo.((3,1,2))) === TupleFoo.((3,1,2))
916946
end
917947

918948
@testset "sort!(v, lo, hi, alg, order)" begin

0 commit comments

Comments
 (0)