Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Add proper tests for row encoding #19843

Merged
merged 4 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/polars-python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ polars-mem-engine = { workspace = true }
polars-ops = { workspace = true, features = ["bitwise"] }
polars-parquet = { workspace = true, optional = true }
polars-plan = { workspace = true }
polars-row = { workspace = true }
polars-time = { workspace = true }
polars-utils = { workspace = true }

Expand Down
42 changes: 42 additions & 0 deletions crates/polars-python/src/dataframe/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -710,4 +710,46 @@ impl PyDataFrame {
let cap = md_cols.capacity();
(ptr as usize, len, cap)
}

/// Row-encode all columns of this frame into a single series named `row_enc`.
///
/// Internal utility that exposes the row encoding to Python. The work is done
/// inside `py.allow_threads`.
///
/// `fields` must contain one `(descending, nulls_last, no_order)` triple per
/// column of the frame, in column order.
///
/// # Panics
///
/// Panics when `fields.len()` differs from the width of the frame.
#[pyo3(signature = (fields))]
fn _row_encode<'py>(
    &'py self,
    py: Python<'py>,
    fields: Vec<(bool, bool, bool)>,
) -> PyResult<PySeries> {
    py.allow_threads(|| {
        // Work on a rechunked copy so that the first chunk of every column
        // is the one that gets encoded below.
        let mut frame = self.df.clone();
        frame.rechunk_mut();

        assert_eq!(frame.width(), fields.len());

        // The encoder is fed the physical representation of each column.
        let mut arrays = Vec::with_capacity(frame.width());
        for column in frame.get_columns() {
            let physical = column.as_materialized_series().to_physical_repr();
            arrays.push(physical.chunks()[0].to_boxed());
        }

        let encoding: Vec<_> = fields
            .into_iter()
            .map(
                |(descending, nulls_last, no_order)| polars_row::EncodingField {
                    descending,
                    nulls_last,
                    no_order,
                },
            )
            .collect();

        let encoded = polars_row::convert_columns(&arrays, &encoding);

        // SAFETY: NOTE(review): relies on `into_array()` producing an array
        // whose physical layout matches `DataType::BinaryOffset` - confirm
        // against the polars-row API.
        let out = unsafe {
            Series::from_chunks_and_dtype_unchecked(
                PlSmallStr::from_static("row_enc"),
                vec![encoded.into_array().boxed()],
                &DataType::BinaryOffset,
            )
        };
        Ok(out.into())
    })
}
}
1 change: 1 addition & 0 deletions crates/polars-python/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::{PyExpr, Wrap};

// Don't change the order of these!
#[repr(u8)]
#[derive(Clone)]
pub(crate) enum PyDataType {
Int8,
Int16,
Expand Down
62 changes: 62 additions & 0 deletions crates/polars-python/src/series/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,68 @@ impl PySeries {
.map_err(PyPolarsErr::from)?;
Ok(out.into())
}

/// Decode a row-encoded series back into a frame.
///
/// Internal utility that exposes the row decoding to Python. The work is done
/// inside `py.allow_threads`.
///
/// * `dtypes` - one `(column name, dtype)` pair per output column.
/// * `fields` - one `(descending, nulls_last, no_order)` triple per output
///   column, in the same order as `dtypes`.
///
/// # Errors
///
/// Returns an error when the series cannot be accessed as `BinaryOffset` or
/// when the resulting `DataFrame` cannot be constructed.
///
/// # Panics
///
/// Panics when `dtypes.len()` differs from `fields.len()`.
#[pyo3(signature = (dtypes, fields))]
fn _row_decode<'py>(
    &'py self,
    py: Python<'py>,
    dtypes: Vec<(String, Wrap<DataType>)>,
    fields: Vec<(bool, bool, bool)>,
) -> PyResult<PyDataFrame> {
    py.allow_threads(|| {
        assert_eq!(dtypes.len(), fields.len());

        let encoding: Vec<_> = fields
            .into_iter()
            .map(
                |(descending, nulls_last, no_order)| polars_row::EncodingField {
                    descending,
                    nulls_last,
                    no_order,
                },
            )
            .collect();

        // The polars-row crate expects the physical arrow types.
        let physical_dtypes: Vec<_> = dtypes
            .iter()
            .map(|(_, dtype)| dtype.0.to_physical().to_arrow(CompatLevel::newest()))
            .collect();

        // View the (rechunked) series as a single BinaryOffset array and
        // collect one byte slice per encoded row.
        let rechunked = self.series.rechunk();
        let binary = rechunked.binary_offset().map_err(PyPolarsErr::from)?;
        assert_eq!(binary.chunks().len(), 1);
        let only_chunk = binary.downcast_iter().next().unwrap();
        let mut rows: Vec<&[u8]> = only_chunk.values_iter().collect();

        // SAFETY: NOTE(review): nothing here validates the bytes; this
        // presumably requires that the rows were produced by the row encoder
        // with matching `fields` and `dtypes` - confirm against
        // `polars_row::decode::decode_rows`.
        let decoded =
            unsafe { polars_row::decode::decode_rows(&mut rows, &encoding, &physical_dtypes) };

        // Construct a DataFrame from the result, re-attaching the requested
        // names and logical dtypes.
        let mut columns = Vec::with_capacity(dtypes.len());
        for (array, (name, dtype)) in decoded.into_iter().zip(dtypes) {
            // SAFETY: NOTE(review): assumes each decoded array has the
            // physical type of the dtype requested for it - not checked here.
            let series = unsafe {
                Series::from_chunks_and_dtype_unchecked(
                    PlSmallStr::from(name),
                    vec![array],
                    &dtype.0,
                )
            };
            columns.push(series.into_column());
        }

        let frame = DataFrame::new(columns).map_err(PyPolarsErr::from)?;
        Ok(frame.into())
    })
}
}

macro_rules! impl_set_with_mask {
Expand Down
17 changes: 17 additions & 0 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -11304,6 +11304,23 @@ def _to_metadata(

return md

def _row_encode(
    self,
    fields: list[tuple[bool, bool, bool]],
) -> Series:
    """
    Row encode this DataFrame into a single Series.

    Internal helper: it is not part of the public API and may be changed or
    removed at any time.

    Parameters
    ----------
    fields
        One tuple per column, laid out as
        `(descending, nulls_last, no_order)`.
    """
    encoded = self._df._row_encode(fields)
    return pl.Series._from_pyseries(encoded)


def _prepare_other_arg(other: Any, length: int | None = None) -> Series:
# if not a series create singleton series such that it will broadcast
Expand Down
18 changes: 18 additions & 0 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -7517,6 +7517,24 @@ def plot(self) -> SeriesPlot:
raise ModuleUpgradeRequiredError(msg)
return SeriesPlot(self)

def _row_decode(
    self,
    dtypes: Iterable[tuple[str, DataType]],  # type: ignore[valid-type]
    fields: Iterable[tuple[bool, bool, bool]],
) -> DataFrame:
    """
    Row decode this Series into a DataFrame.

    Internal helper: it is not part of the public API and may be changed or
    removed at any time.

    Parameters
    ----------
    dtypes
        One `(column name, dtype)` pair per column to decode.
    fields
        One tuple per column, laid out as
        `(descending, nulls_last, no_order)`.
    """
    # The native method takes concrete lists, so materialize both iterables.
    dtype_list = list(dtypes)
    field_list = list(fields)
    decoded = self._s._row_decode(dtype_list, field_list)
    return pl.DataFrame._from_pydf(decoded)


def _resolve_temporal_dtype(
dtype: PolarsDataType | None,
Expand Down
146 changes: 146 additions & 0 deletions py-polars/tests/unit/test_row_encoding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
from __future__ import annotations

import pytest
from hypothesis import given

import polars as pl
from polars.testing import assert_frame_equal
from polars.testing.parametric import dataframes

# Every (descending, nulls_last, no_order) combination exercised by the tests.
# @TODO: Deal with no_order
FIELD_COMBS = [
    (False, False, False),
    (False, True, False),
    (True, False, False),
    (True, True, False),
]


def roundtrip_re(
    df: pl.DataFrame, fields: list[tuple[bool, bool, bool]] | None = None
) -> None:
    """
    Assert that row encoding followed by row decoding reproduces `df`.

    When `fields` is omitted, every column uses `(False, False, False)`.
    """
    if fields is None:
        fields = [(False, False, False)] * df.width

    # Decoding needs the name and dtype of every column of the input frame.
    schema = [(name, df.get_column(name).dtype) for name in df.columns]
    decoded = df._row_encode(fields)._row_decode(schema, fields)

    assert_frame_equal(df, decoded)


@given(
    df=dataframes(
        excluded_dtypes=[
            pl.List,
            pl.Array,
            pl.Struct,
            pl.Categorical,
            pl.Enum,
        ]
    )
)
@pytest.mark.parametrize("field", FIELD_COMBS)
def test_row_encoding_parametric(
    df: pl.DataFrame, field: tuple[bool, bool, bool]
) -> None:
    """Roundtrip generated frames, using the same field for every column."""
    per_column_fields = [field for _ in range(df.width)]
    roundtrip_re(df, per_column_fields)


@pytest.mark.parametrize("field", FIELD_COMBS)
def test_nulls(field: tuple[bool, bool, bool]) -> None:
    """Roundtrip all-null columns of several lengths, including empty."""
    for length in (0, 1, 2, 13, 42):
        column = pl.Series("a", [None] * length, pl.Null)
        roundtrip_re(column.to_frame(), [field])


@pytest.mark.parametrize("field", FIELD_COMBS)
def test_bool(field: tuple[bool, bool, bool]) -> None:
    """Roundtrip boolean columns: empty, single values and both orderings."""
    cases: list[list[bool]] = [
        [],
        [False],
        [True],
        [False, True],
        [True, False],
    ]
    for values in cases:
        roundtrip_re(pl.Series("a", values, pl.Boolean).to_frame(), [field])


@pytest.mark.parametrize(
    "dtype",
    [
        pl.Int8,
        pl.Int16,
        pl.Int32,
        pl.Int64,
        pl.UInt8,
        pl.UInt16,
        pl.UInt32,
        pl.UInt64,
    ],
)
@pytest.mark.parametrize("field", FIELD_COMBS)
def test_int(dtype: pl.DataType, field: tuple[bool, bool, bool]) -> None:
    """Roundtrip integer columns, including the dtype's minimum and maximum."""
    lowest = pl.select(x=dtype.min()).item()  # type: ignore[attr-defined]
    highest = pl.select(x=dtype.max()).item()  # type: ignore[attr-defined]

    cases: list[list[int]] = [
        [],
        [0],
        [lowest],
        [highest],
        [1, 2, 3],
        [0, 1, 2, 3],
        [lowest, 0, highest],
    ]
    for values in cases:
        roundtrip_re(pl.Series("a", values, dtype).to_frame(), [field])


@pytest.mark.parametrize(
    "dtype",
    [
        pl.Float32,
        pl.Float64,
    ],
)
@pytest.mark.parametrize("field", FIELD_COMBS)
def test_float(dtype: pl.DataType, field: tuple[bool, bool, bool]) -> None:
    """Roundtrip float columns, including positive and negative infinity."""
    pos_inf = float("inf")
    neg_inf = float("-inf")

    roundtrip_re(pl.Series("a", [], dtype).to_frame(), [field])
    roundtrip_re(pl.Series("a", [0.0], dtype).to_frame(), [field])
    roundtrip_re(pl.Series("a", [pos_inf], dtype).to_frame(), [field])
    # Pass negative infinity as-is: negating it would give +inf again and
    # leave -inf untested.
    roundtrip_re(pl.Series("a", [neg_inf], dtype).to_frame(), [field])

    roundtrip_re(pl.Series("a", [1.0, 2.0, 3.0], dtype).to_frame(), [field])
    roundtrip_re(pl.Series("a", [0.0, 1.0, 2.0, 3.0], dtype).to_frame(), [field])
    roundtrip_re(pl.Series("a", [pos_inf, 0.0, neg_inf], dtype).to_frame(), [field])


@pytest.mark.parametrize("field", FIELD_COMBS)
def test_str(field: tuple[bool, bool, bool]) -> None:
    """Roundtrip string columns: empty column, empty strings, mixed lengths."""
    cases: list[list[str]] = [
        [],
        [""],
        ["a", "b", "c"],
        ["", "a", "b", "c"],
        ["different", "length", "strings"],
        ["different", "", "length", "", "strings"],
    ]
    for values in cases:
        roundtrip_re(pl.Series("a", values, pl.String).to_frame(), [field])


# def test_struct() -> None:
# # @TODO: How do we deal with zero-field structs?
# # roundtrip_re(pl.Series('a', [], pl.Struct({})).to_frame())
# # roundtrip_re(pl.Series('a', [{}], pl.Struct({})).to_frame())
# roundtrip_re(pl.Series("a", [{"x": 1}], pl.Struct({"x": pl.Int32})).to_frame())
# roundtrip_re(
# pl.Series(
# "a", [{"x": 1}, {"y": 2}], pl.Struct({"x": pl.Int32, "y": pl.Int32})
# ).to_frame()
# )