Skip to content

Commit

Permalink
feat: Implement merge_sorted for struct (#21205)
Browse files Browse the repository at this point in the history
  • Loading branch information
JakubValtar authored Feb 13, 2025
1 parent 918bb1a commit 9ccfddd
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 6 deletions.
21 changes: 16 additions & 5 deletions crates/polars-ops/src/frame/join/merge_sorted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ pub fn _merge_sorted_dfs(
return Ok(right.clone());
}

let merge_indicator = series_to_merge_indicator(left_s, right_s);
let merge_indicator = series_to_merge_indicator(left_s, right_s)?;
let new_columns = left
.get_columns()
.iter()
Expand Down Expand Up @@ -90,7 +90,10 @@ fn merge_series(lhs: &Series, rhs: &Series, merge_indicator: &[bool]) -> PolarsR
.fields_as_series()
.iter()
.zip(rhs.fields_as_series())
.map(|(lhs, rhs)| merge_series(lhs, &rhs, merge_indicator))
.map(|(lhs, rhs)| {
merge_series(lhs, &rhs, merge_indicator)
.map(|merged| merged.with_name(lhs.name().clone()))
})
.collect::<PolarsResult<Vec<_>>>()?;
StructChunked::from_series(PlSmallStr::EMPTY, new_fields[0].len(), new_fields.iter())
.unwrap()
Expand Down Expand Up @@ -139,11 +142,11 @@ where
unsafe { iter.trust_my_length(total_len).collect_trusted() }
}

fn series_to_merge_indicator(lhs: &Series, rhs: &Series) -> Vec<bool> {
fn series_to_merge_indicator(lhs: &Series, rhs: &Series) -> PolarsResult<Vec<bool>> {
let lhs_s = lhs.to_physical_repr().into_owned();
let rhs_s = rhs.to_physical_repr().into_owned();

match lhs_s.dtype() {
let out = match lhs_s.dtype() {
DataType::Boolean => {
let lhs = lhs_s.bool().unwrap();
let rhs = rhs_s.bool().unwrap();
Expand All @@ -159,6 +162,13 @@ fn series_to_merge_indicator(lhs: &Series, rhs: &Series) -> Vec<bool> {
let rhs = rhs_s.binary().unwrap();
get_merge_indicator(lhs.into_iter(), rhs.into_iter())
},
#[cfg(feature = "dtype-struct")]
DataType::Struct(_) => {
let options = SortOptions::default();
let lhs = lhs_s.struct_().unwrap().get_row_encoded(options)?;
let rhs = rhs_s.struct_().unwrap().get_row_encoded(options)?;
get_merge_indicator(lhs.into_iter(), rhs.into_iter())
},
_ => {
with_match_physical_numeric_polars_type!(lhs_s.dtype(), |$T| {
let lhs: &ChunkedArray<$T> = lhs_s.as_ref().as_ref().as_ref();
Expand All @@ -168,7 +178,8 @@ fn series_to_merge_indicator(lhs: &Series, rhs: &Series) -> Vec<bool> {

})
},
}
};
Ok(out)
}

// get a boolean values, left: true, right: false
Expand Down
27 changes: 26 additions & 1 deletion py-polars/tests/unit/operations/test_merge_sorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,36 @@ def test_merge_sorted_parametric_string(lhs: pl.Series, rhs: pl.Series) -> None:
assert_series_equal(merge_sorted, append_sorted)


@given(
lhs=series(
name="a",
allowed_dtypes=[
pl.Struct({"x": pl.Int32, "y": pl.Struct({"x": pl.Int8, "y": pl.Int8})})
],
allow_null=False,
), # Nulls see: https://github.com/pola-rs/polars/issues/20991
rhs=series(
name="a",
allowed_dtypes=[
pl.Struct({"x": pl.Int32, "y": pl.Struct({"x": pl.Int8, "y": pl.Int8})})
],
allow_null=False,
), # Nulls see: https://github.com/pola-rs/polars/issues/20991
)
def test_merge_sorted_parametric_struct(lhs: pl.Series, rhs: pl.Series) -> None:
l_df = pl.DataFrame([lhs.sort()])
r_df = pl.DataFrame([rhs.sort()])

merge_sorted = l_df.lazy().merge_sorted(r_df.lazy(), "a").collect().get_column("a")
append_sorted = lhs.append(rhs).sort()

assert_series_equal(merge_sorted, append_sorted)


@given(
s=series(
name="a",
excluded_dtypes=[
pl.Struct, # Bug. See https://github.com/pola-rs/polars/issues/20986
pl.Categorical(
ordering="lexical"
), # Bug. See https://github.com/pola-rs/polars/issues/21025
Expand Down

0 comments on commit 9ccfddd

Please sign in to comment.