Skip to content

Commit

Permalink
feat: Implement sorted flags for struct series (#21290)
Browse files Browse the repository at this point in the history
  • Loading branch information
JakubValtar authored Feb 19, 2025
1 parent 43369fb commit 4e286d8
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 18 deletions.
25 changes: 9 additions & 16 deletions crates/polars-core/src/chunked_array/ops/sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use compare_inner::NonNull;
use rayon::prelude::*;
pub use slice::*;

use crate::chunked_array::ops::row_encode::_get_rows_encoded_ca;
use crate::prelude::compare_inner::TotalOrdInner;
use crate::prelude::sort::arg_sort_multiple::*;
use crate::prelude::*;
Expand Down Expand Up @@ -645,26 +644,20 @@ impl ChunkSort<BinaryOffsetType> for BinaryOffsetChunked {
}
}

#[cfg(feature = "dtype-struct")]
impl StructChunked {
pub(crate) fn arg_sort(&self, options: SortOptions) -> IdxCa {
let bin = _get_rows_encoded_ca(
self.name().clone(),
&[self.clone().into_column()],
&[options.descending],
&[options.nulls_last],
)
.unwrap();
bin.arg_sort(Default::default())
}
}

#[cfg(feature = "dtype-struct")]
impl ChunkSort<StructType> for StructChunked {
fn sort_with(&self, mut options: SortOptions) -> ChunkedArray<StructType> {
options.multithreaded &= POOL.current_num_threads() > 1;
let idx = self.arg_sort(options);
unsafe { self.take_unchecked(&idx) }
let mut out = unsafe { self.take_unchecked(&idx) };

let s = if options.descending {
IsSorted::Descending
} else {
IsSorted::Ascending
};
out.set_sorted_flag(s);
out
}

fn sort(&self, descending: bool) -> ChunkedArray<StructType> {
Expand Down
6 changes: 4 additions & 2 deletions crates/polars-core/src/series/implementations/struct_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@ impl PrivateSeries for SeriesWrap<StructChunked> {
}

fn _get_flags(&self) -> StatisticsFlags {
StatisticsFlags::empty()
self.0.get_flags()
}

fn _set_flags(&mut self, _flags: StatisticsFlags) {}
fn _set_flags(&mut self, flags: StatisticsFlags) {
self.0.set_flags(flags);
}

// TODO! remove this. Very slow. Asof join should use row-encoding.
unsafe fn equal_element(&self, idx_self: usize, idx_other: usize, other: &Series) -> bool {
Expand Down
10 changes: 10 additions & 0 deletions py-polars/tests/unit/operations/test_is_sorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,3 +427,13 @@ def test_is_sorted_chunked_select() -> None:
def test_is_sorted_arithmetic_overflow_14106() -> None:
s = pl.Series([0, 200], dtype=pl.UInt8).sort()
assert not (s + 200).is_sorted()


def test_is_sorted_struct() -> None:
s = pl.Series("a", [{"x": 3}, {"x": 1}, {"x": 2}]).sort()
assert s.flags["SORTED_ASC"]
assert not s.flags["SORTED_DESC"]

s = s.sort(descending=True)
assert s.flags["SORTED_DESC"]
assert not s.flags["SORTED_ASC"]

0 comments on commit 4e286d8

Please sign in to comment.