Skip to content

Commit

Permalink
fix: Crash/incorrect group_by/n_unique on categoricals created by (q)…
Browse files Browse the repository at this point in the history
…cut (#16006)
  • Loading branch information
nameexhaustion authored May 2, 2024
1 parent d5cf038 commit ebd8aec
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,10 @@ impl CategoricalChunked {
self
}

pub fn _with_fast_unique(self, toggle: bool) -> Self {
self.with_fast_unique(toggle)
}

/// Get a reference to the mapping of categorical types to the string values.
pub fn get_rev_map(&self) -> &Arc<RevMapping> {
if let DataType::Categorical(Some(rev_map), _) | DataType::Enum(Some(rev_map), _) =
Expand Down
54 changes: 47 additions & 7 deletions crates/polars-ops/src/series/ops/cut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ fn map_cats(
PartialOrd::gt
};

// Ensure fast unique is only set if all labels were seen.
let mut label_has_value = vec![false; 1 + sorted_breaks.len()];

if include_breaks {
// This is to replicate the behavior of the old buggy version that only worked on series and
// returned a dataframe. That included a column of the right endpoint of the interval. So we
Expand All @@ -33,8 +36,11 @@ fn map_cats(
let mut brk_vals = PrimitiveChunkedBuilder::<Float64Type>::new("brk", s.len());
s_iter
.map(|opt| {
opt.filter(|x| !x.is_nan())
.map(|x| sorted_breaks.partition_point(|v| op(&x, v)))
opt.filter(|x| !x.is_nan()).map(|x| {
let pt = sorted_breaks.partition_point(|v| op(&x, v));
unsafe { *label_has_value.get_unchecked_mut(pt) = true };
pt
})
})
.for_each(|idx| match idx {
None => {
Expand All @@ -47,17 +53,23 @@ fn map_cats(
},
});

let outvals = vec![brk_vals.finish().into_series(), bld.finish().into_series()];
let outvals = vec![
brk_vals.finish().into_series(),
bld.finish()
._with_fast_unique(label_has_value.iter().all(bool::clone))
.into_series(),
];
Ok(StructChunked::new(&out_name, &outvals)?.into_series())
} else {
Ok(bld
.drain_iter_and_finish(s_iter.map(|opt| {
opt.filter(|x| !x.is_nan()).map(|x| unsafe {
labels
.get_unchecked(sorted_breaks.partition_point(|v| op(&x, v)))
.as_str()
opt.filter(|x| !x.is_nan()).map(|x| {
let pt = sorted_breaks.partition_point(|v| op(&x, v));
unsafe { *label_has_value.get_unchecked_mut(pt) = true };
unsafe { labels.get_unchecked(pt).as_str() }
})
}))
._with_fast_unique(label_has_value.iter().all(bool::clone))
.into_series())
}
}
Expand Down Expand Up @@ -145,3 +157,31 @@ pub fn qcut(

map_cats(&s, &cut_labels, &qbreaks, left_closed, include_breaks)
}

// Gate the module so it is only compiled for test builds.
#[cfg(test)]
mod test {
    use polars_core::prelude::*;

    use super::map_cats;

    /// Runs `map_cats` over the series `[1, 2, 3, 4, 5]` with the labels
    /// `a`, `b`, `c` (right-closed bins) and returns what `_can_fast_unique`
    /// reports for:
    ///
    /// 1. the plain categorical output (`include_breaks = false`), and
    /// 2. the categorical stored as field 1 of the struct output
    ///    (`include_breaks = true`).
    fn fast_unique_flags(breaks: &[f64]) -> (bool, bool) {
        let s = Series::new("x", &[1, 2, 3, 4, 5]);
        let labels = ["a", "b", "c"].map(str::to_owned);
        let left_closed = false;

        let out = map_cats(&s, &labels, breaks, left_closed, false).unwrap();
        let plain = out.categorical().unwrap()._can_fast_unique();

        let out = map_cats(&s, &labels, breaks, left_closed, true).unwrap();
        // Field 0 holds the break values; field 1 is the categorical column.
        let field = out.struct_().unwrap().fields()[1].clone();
        let nested = field.categorical().unwrap()._can_fast_unique();

        (plain, nested)
    }

    /// The fast unique flag is not visible to Python, so check here that it is
    /// set when every label is hit by at least one value.
    #[test]
    fn test_map_cats_fast_unique() {
        let (plain, nested) = fast_unique_flags(&[2.0, 4.0]);
        assert!(plain);
        assert!(nested);
    }

    /// With breaks far above the data every value lands in the first bin, so
    /// two of the three labels are never seen and the flag must stay unset.
    #[test]
    fn test_map_cats_no_fast_unique_when_label_unseen() {
        let (plain, nested) = fast_unique_flags(&[99.0, 101.0]);
        assert!(!plain);
        assert!(!nested);
    }
}
44 changes: 44 additions & 0 deletions py-polars/tests/unit/operations/test_cut.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,47 @@ def test_cut_bin_name_in_agg_context() -> None:
)
schema = pl.Struct({"brk": pl.Float64, "a_bin": pl.Categorical("physical")})
assert df.schema == {"cut": schema, "qcut": schema, "qcut_uniform": schema}


@pytest.mark.parametrize(
    ("breaks", "expected_labels", "expected_physical", "expected_unique"),
    [
        (
            [2, 4],
            pl.Series("x", ["(-inf, 2]", "(-inf, 2]", "(2, 4]", "(2, 4]", "(4, inf]"]),
            pl.Series("x", [0, 0, 1, 1, 2], dtype=pl.UInt32),
            3,
        ),
        (
            [99, 101],
            pl.Series("x", 5 * ["(-inf, 99]"]),
            pl.Series("x", 5 * [0], dtype=pl.UInt32),
            1,
        ),
    ],
)
def test_cut_fast_unique_15981(
    breaks: list[int],
    expected_labels: pl.Series,
    expected_physical: pl.Series,
    expected_unique: int,
) -> None:
    """Check labels, physical codes and unique counts of `cut` output.

    The same checks run on the plain result (`include_breaks=False`) and on
    the "category" field of the struct result (`include_breaks=True`).
    """
    s = pl.Series("x", [1, 2, 3, 4, 5])

    def _verify(binned: pl.Series) -> None:
        # Labels and physical codes must match the parametrized expectation.
        assert_series_equal(binned.cast(pl.String), expected_labels)
        assert_series_equal(binned.to_physical(), expected_physical)
        # n_unique on the categorical must agree with n_unique on its codes.
        assert binned.n_unique() == binned.to_physical().n_unique() == expected_unique
        # No assertion on the result: this only has to run without raising.
        binned.to_frame().group_by(s.name).len()

    _verify(s.cut(breaks, include_breaks=False))

    category_field = s.cut(breaks, include_breaks=True).struct.field("category")
    _verify(category_field.alias("x"))

0 comments on commit ebd8aec

Please sign in to comment.