diff --git a/crates/polars-io/src/cloud/credential_provider.rs b/crates/polars-io/src/cloud/credential_provider.rs index 989e0318120c..e6de837488c1 100644 --- a/crates/polars-io/src/cloud/credential_provider.rs +++ b/crates/polars-io/src/cloud/credential_provider.rs @@ -343,9 +343,6 @@ impl serde::Serialize for PlCredentialProvider { { use serde::ser::Error; - // TODO: - // * Add magic bytes here to indicate a python function - // * Check the Python version on deserialize #[cfg(feature = "python")] if let PlCredentialProvider::Python(v) = self { return v.serialize(serializer); diff --git a/crates/polars-ops/src/chunked_array/list/to_struct.rs b/crates/polars-ops/src/chunked_array/list/to_struct.rs index 6676de3983db..fad1bcebb9a1 100644 --- a/crates/polars-ops/src/chunked_array/list/to_struct.rs +++ b/crates/polars-ops/src/chunked_array/list/to_struct.rs @@ -5,82 +5,220 @@ use polars_utils::pl_str::PlSmallStr; use super::*; -#[derive(Copy, Clone, Debug)] +#[derive(Clone, Eq, PartialEq, Hash, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub enum ListToStructArgs { + FixedWidth(Arc<[PlSmallStr]>), + InferWidth { + infer_field_strategy: ListToStructWidthStrategy, + get_index_name: Option, + /// If this is 0, it means unbounded. + max_fields: usize, + }, +} + +#[derive(Clone, Eq, PartialEq, Hash, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub enum ListToStructWidthStrategy { FirstNonNull, MaxWidth, } -fn det_n_fields(ca: &ListChunked, n_fields: ListToStructWidthStrategy) -> usize { - match n_fields { - ListToStructWidthStrategy::MaxWidth => { - let mut max = 0; - - ca.downcast_iter().for_each(|arr| { - let offsets = arr.offsets().as_slice(); - let mut last = offsets[0]; - for o in &offsets[1..] { - let len = (*o - last) as usize; - max = std::cmp::max(max, len); - last = *o; +impl ListToStructArgs { + pub fn get_output_dtype(&self, input_dtype: &DataType) -> PolarsResult { + let DataType::List(inner_dtype) = input_dtype else { + polars_bail!( + InvalidOperation: + "attempted list to_struct on non-list dtype: {}", + input_dtype + ); + }; + let inner_dtype = inner_dtype.as_ref(); + + match self { + Self::FixedWidth(names) => Ok(DataType::Struct( + names + .iter() + .map(|x| Field::new(x.clone(), inner_dtype.clone())) + .collect::>(), + )), + Self::InferWidth { + get_index_name, + max_fields, + .. + } if *max_fields > 0 => { + let get_index_name_func = get_index_name.as_ref().map_or( + &_default_struct_name_gen as &dyn Fn(usize) -> PlSmallStr, + |x| x.0.as_ref(), + ); + Ok(DataType::Struct( + (0..*max_fields) + .map(|i| Field::new(get_index_name_func(i), inner_dtype.clone())) + .collect::>(), + )) + }, + Self::InferWidth { .. } => Ok(DataType::Unknown(UnknownKind::Any)), + } + } + + fn det_n_fields(&self, ca: &ListChunked) -> usize { + match self { + Self::FixedWidth(v) => v.len(), + Self::InferWidth { + infer_field_strategy, + max_fields, + .. + } => { + let inferred = match infer_field_strategy { + ListToStructWidthStrategy::MaxWidth => { + let mut max = 0; + + ca.downcast_iter().for_each(|arr| { + let offsets = arr.offsets().as_slice(); + let mut last = offsets[0]; + for o in &offsets[1..] { + let len = (*o - last) as usize; + max = std::cmp::max(max, len); + last = *o; + } + }); + max + }, + ListToStructWidthStrategy::FirstNonNull => { + let mut len = 0; + for arr in ca.downcast_iter() { + let offsets = arr.offsets().as_slice(); + let mut last = offsets[0]; + for o in &offsets[1..] { + len = (*o - last) as usize; + if len > 0 { + break; + } + last = *o; + } + if len > 0 { + break; + } + } + len + }, + }; + + if *max_fields > 0 { + inferred.min(*max_fields) + } else { + inferred } - }); - max - }, - ListToStructWidthStrategy::FirstNonNull => { - let mut len = 0; - for arr in ca.downcast_iter() { - let offsets = arr.offsets().as_slice(); - let mut last = offsets[0]; - for o in &offsets[1..] { - len = (*o - last) as usize; - if len > 0 { - break; - } - last = *o; + }, + } + } + + fn set_output_names(&self, columns: &mut [Series]) { + match self { + Self::FixedWidth(v) => { + assert_eq!(columns.len(), v.len()); + + for (c, name) in columns.iter_mut().zip(v.iter()) { + c.rename(name.clone()); } - if len > 0 { - break; + }, + Self::InferWidth { get_index_name, .. } => { + let get_index_name_func = get_index_name.as_ref().map_or( + &_default_struct_name_gen as &dyn Fn(usize) -> PlSmallStr, + |x| x.0.as_ref(), + ); + + for (i, c) in columns.iter_mut().enumerate() { + c.rename(get_index_name_func(i)); } - } - len - }, + }, + } + } +} + +#[derive(Clone)] +pub struct NameGenerator(pub Arc PlSmallStr + Send + Sync>); + +impl NameGenerator { + pub fn from_func(func: impl Fn(usize) -> PlSmallStr + Send + Sync + 'static) -> Self { + Self(Arc::new(func)) + } +} + +impl std::fmt::Debug for NameGenerator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "list::to_struct::NameGenerator function at 0x{:016x}", + self.0.as_ref() as *const _ as *const () as usize + ) + } +} + +impl Eq for NameGenerator {} + +impl PartialEq for NameGenerator { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.0, &other.0) } } -pub type NameGenerator = Arc PlSmallStr + Send + Sync>; +impl std::hash::Hash for NameGenerator { + fn hash(&self, state: &mut H) { + state.write_usize(Arc::as_ptr(&self.0) as *const () as usize) + } +} pub fn _default_struct_name_gen(idx: usize) -> PlSmallStr { format_pl_smallstr!("field_{idx}") } pub trait ToStruct: AsList { - fn to_struct( - &self, - n_fields: ListToStructWidthStrategy, - name_generator: Option, - ) -> PolarsResult { + fn to_struct(&self, args: &ListToStructArgs) -> PolarsResult { let ca = self.as_list(); - let n_fields = det_n_fields(ca, n_fields); + let n_fields = args.det_n_fields(ca); - let name_generator = name_generator - .as_deref() - .unwrap_or(&_default_struct_name_gen); - - let fields = POOL.install(|| { + let mut fields = POOL.install(|| { (0..n_fields) .into_par_iter() - .map(|i| { - ca.lst_get(i as i64, true).map(|mut s| { - s.rename(name_generator(i)); - s - }) - }) + .map(|i| ca.lst_get(i as i64, true)) .collect::>>() })?; + args.set_output_names(&mut fields); + StructChunked::from_series(ca.name().clone(), ca.len(), fields.iter()) } } impl ToStruct for ListChunked {} + +#[cfg(feature = "serde")] +mod _serde_impl { + use super::*; + + impl serde::Serialize for NameGenerator { + fn serialize(&self, _serializer: S) -> Result + where + S: serde::Serializer, + { + use serde::ser::Error; + Err(S::Error::custom( + "cannot serialize name generator function for to_struct, \ + consider passing a list of field names instead.", + )) + } + } + + impl<'de> serde::Deserialize<'de> for NameGenerator { + fn deserialize(_deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use serde::de::Error; + Err(D::Error::custom( + "invalid data: attempted to deserialize list::to_struct::NameGenerator", + )) + } + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index b9a07d97341a..ddf8fb1fff20 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -4,7 +4,7 @@ use polars_ops::chunked_array::list::*; use super::*; use crate::{map, map_as_slice, wrap}; -#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)] +#[derive(Clone, Eq, PartialEq, Hash, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum ListFunction { Concat, @@ -56,6 +56,8 @@ pub enum ListFunction { Join(bool), #[cfg(feature = "dtype-array")] ToArray(usize), + #[cfg(feature = "list_to_struct")] + ToStruct(ListToStructArgs), } impl ListFunction { @@ -103,6 +105,8 @@ impl ListFunction { #[cfg(feature = "dtype-array")] ToArray(width) => mapper.try_map_dtype(|dt| map_list_dtype_to_array_dtype(dt, *width)), NUnique => mapper.with_dtype(IDX_DTYPE), + #[cfg(feature = "list_to_struct")] + ToStruct(args) => mapper.try_map_dtype(|x| args.get_output_dtype(x)), } } } @@ -174,6 +178,8 @@ impl Display for ListFunction { Join(_) => "join", #[cfg(feature = "dtype-array")] ToArray(_) => "to_array", + #[cfg(feature = "list_to_struct")] + ToStruct(_) => "to_struct", }; write!(f, "list.{name}") } @@ -235,6 +241,8 @@ impl From for SpecialEq> { #[cfg(feature = "dtype-array")] ToArray(width) => map!(to_array, width), NUnique => map!(n_unique), + #[cfg(feature = "list_to_struct")] + ToStruct(args) => map!(to_struct, &args), } } } @@ -650,6 +658,11 @@ pub(super) fn to_array(s: &Column, width: usize) -> PolarsResult { s.cast(&array_dtype) } +#[cfg(feature = "list_to_struct")] +pub(super) fn to_struct(s: &Column, args: &ListToStructArgs) -> PolarsResult { + Ok(s.list()?.to_struct(args)?.into_series().into()) +} + pub(super) fn n_unique(s: &Column) -> PolarsResult { Ok(s.list()?.lst_n_unique()?.into_column()) } diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index fb0c7a83b463..bc4c468eadaf 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -1,6 +1,3 @@ -#[cfg(feature = "list_to_struct")] -use std::sync::RwLock; - use polars_core::prelude::*; #[cfg(feature = "diff")] use polars_core::series::ops::NullBehavior; @@ -281,50 +278,9 @@ impl ListNameSpace { /// an `upper_bound` of struct fields that will be set. /// If this is incorrectly downstream operation may fail. For instance an `all().sum()` expression /// will look in the current schema to determine which columns to select. - pub fn to_struct( - self, - n_fields: ListToStructWidthStrategy, - name_generator: Option, - upper_bound: usize, - ) -> Expr { - // heap allocate the output type and fill it later - let out_dtype = Arc::new(RwLock::new(None::)); - + pub fn to_struct(self, args: ListToStructArgs) -> Expr { self.0 - .map( - move |s| { - s.list()? - .to_struct(n_fields, name_generator.clone()) - .map(|s| Some(s.into_column())) - }, - // we don't yet know the fields - GetOutput::map_dtype(move |dt: &DataType| { - polars_ensure!(matches!(dt, DataType::List(_)), SchemaMismatch: "expected 'List' as input to 'list.to_struct' got {}", dt); - let out = out_dtype.read().unwrap(); - match out.as_ref() { - // dtype already set - Some(dt) => Ok(dt.clone()), - // dtype still unknown, set it - None => { - drop(out); - let mut lock = out_dtype.write().unwrap(); - - let inner = dt.inner_dtype().unwrap(); - let fields = (0..upper_bound) - .map(|i| { - let name = _default_struct_name_gen(i); - Field::new(name, inner.clone()) - }) - .collect(); - let dt = DataType::Struct(fields); - - *lock = Some(dt.clone()); - Ok(dt) - }, - } - }), - ) - .with_fmt("list.to_struct") + .map_private(FunctionExpr::ListExpr(ListFunction::ToStruct(args))) } #[cfg(feature = "is_in")] diff --git a/crates/polars-python/src/expr/list.rs b/crates/polars-python/src/expr/list.rs index 1bd087144634..af3be10449b1 100644 --- a/crates/polars-python/src/expr/list.rs +++ b/crates/polars-python/src/expr/list.rs @@ -4,6 +4,7 @@ use polars::prelude::*; use polars::series::ops::NullBehavior; use polars_utils::pl_str::PlSmallStr; use pyo3::prelude::*; +use pyo3::types::PySequence; use crate::conversion::Wrap; use crate::PyExpr; @@ -214,20 +215,39 @@ impl PyExpr { upper_bound: usize, ) -> PyResult { let name_gen = name_gen.map(|lambda| { - Arc::new(move |idx: usize| { + NameGenerator::from_func(move |idx: usize| { Python::with_gil(|py| { let out = lambda.call1(py, (idx,)).unwrap(); let out: PlSmallStr = out.extract::>(py).unwrap().as_ref().into(); out }) - }) as NameGenerator + }) }); Ok(self .inner .clone() .list() - .to_struct(width_strat.0, name_gen, upper_bound) + .to_struct(ListToStructArgs::InferWidth { + infer_field_strategy: width_strat.0, + get_index_name: name_gen, + max_fields: upper_bound, + }) + .into()) + } + + #[pyo3(signature = (names))] + fn list_to_struct_fixed_width(&self, names: Bound<'_, PySequence>) -> PyResult { + Ok(self + .inner + .clone() + .list() + .to_struct(ListToStructArgs::FixedWidth( + names + .iter()? + .map(|x| Ok(x?.extract::>()?.0)) + .collect::>>()?, + )) .into()) } diff --git a/py-polars/polars/_typing.py b/py-polars/polars/_typing.py index da5c30ef996e..67a06a2c689e 100644 --- a/py-polars/polars/_typing.py +++ b/py-polars/polars/_typing.py @@ -160,9 +160,7 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: RollingInterpolationMethod: TypeAlias = Literal[ "nearest", "higher", "lower", "midpoint", "linear" ] # QuantileInterpolOptions -ToStructStrategy: TypeAlias = Literal[ - "first_non_null", "max_width" -] # ListToStructWidthStrategy +ListToStructWidthStrategy: TypeAlias = Literal["first_non_null", "max_width"] # The following have no equivalent on the Rust side ConcatMethod = Literal[ diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index 48b4d1da9c49..4d239460d6b5 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -16,8 +16,8 @@ from polars._typing import ( IntoExpr, IntoExprColumn, + ListToStructWidthStrategy, NullBehavior, - ToStructStrategy, ) @@ -1092,7 +1092,7 @@ def to_array(self, width: int) -> Expr: def to_struct( self, - n_field_strategy: ToStructStrategy = "first_non_null", + n_field_strategy: ListToStructWidthStrategy = "first_non_null", fields: Sequence[str] | Callable[[int], str] | None = None, upper_bound: int = 0, ) -> Expr: @@ -1180,9 +1180,8 @@ def to_struct( [{'n': {'one': 0, 'two': 1}}, {'n': {'one': 2, 'two': 3}}] """ if isinstance(fields, Sequence): - field_names = list(fields) - pyexpr = self._pyexpr.list_to_struct(n_field_strategy, None, upper_bound) - return wrap_expr(pyexpr).struct.rename_fields(field_names) + pyexpr = self._pyexpr.list_to_struct_fixed_width(fields) + return wrap_expr(pyexpr) else: pyexpr = self._pyexpr.list_to_struct(n_field_strategy, fields, upper_bound) return wrap_expr(pyexpr) diff --git a/py-polars/polars/series/list.py b/py-polars/polars/series/list.py index cf70f5225f56..0c4b08982606 100644 --- a/py-polars/polars/series/list.py +++ b/py-polars/polars/series/list.py @@ -14,8 +14,8 @@ from polars._typing import ( IntoExpr, IntoExprColumn, + ListToStructWidthStrategy, NullBehavior, - ToStructStrategy, ) from polars.polars import PySeries @@ -855,7 +855,7 @@ def to_array(self, width: int) -> Series: def to_struct( self, - n_field_strategy: ToStructStrategy = "first_non_null", + n_field_strategy: ListToStructWidthStrategy = "first_non_null", fields: Callable[[int], str] | Sequence[str] | None = None, ) -> Series: """ diff --git a/py-polars/tests/unit/operations/namespaces/list/test_list.py b/py-polars/tests/unit/operations/namespaces/list/test_list.py index 78bda8a3637d..63beb278b392 100644 --- a/py-polars/tests/unit/operations/namespaces/list/test_list.py +++ b/py-polars/tests/unit/operations/namespaces/list/test_list.py @@ -7,7 +7,11 @@ import pytest import polars as pl -from polars.exceptions import ComputeError, OutOfBoundsError, SchemaError +from polars.exceptions import ( + ComputeError, + OutOfBoundsError, + SchemaError, +) from polars.testing import assert_frame_equal, assert_series_equal @@ -643,6 +647,26 @@ def test_list_to_struct() -> None: {"n": {"one": 0, "two": 1, "three": None}}, ] + q = df.lazy().select( + pl.col("n").list.to_struct(fields=["a", "b"]).struct.field("a") + ) + + assert_frame_equal(q.collect(), pl.DataFrame({"a": [0, 0]})) + + # Check that: + # * Specifying an upper bound calls the field name getter function to + # retrieve the lazy schema + # * The upper bound is respected during execution + q = df.lazy().select( + pl.col("n").list.to_struct(fields=str, upper_bound=2).struct.unnest() + ) + assert q.collect_schema() == {"0": pl.Int64, "1": pl.Int64} + assert_frame_equal(q.collect(), pl.DataFrame({"0": [0, 0], "1": [1, 1]})) + + assert df.lazy().select(pl.col("n").list.to_struct()).collect_schema() == { + "n": pl.Unknown + } + def test_select_from_list_to_struct_11143() -> None: ldf = pl.LazyFrame({"some_col": [[1.0, 2.0], [1.5, 3.0]]}) diff --git a/py-polars/tests/unit/test_errors.py b/py-polars/tests/unit/test_errors.py index c730ee8d30a7..7a1c2111399a 100644 --- a/py-polars/tests/unit/test_errors.py +++ b/py-polars/tests/unit/test_errors.py @@ -696,7 +696,7 @@ def test_no_panic_pandas_nat() -> None: def test_list_to_struct_invalid_type() -> None: - with pytest.raises(pl.exceptions.SchemaError): + with pytest.raises(pl.exceptions.InvalidOperationError): pl.DataFrame({"a": 1}).select(pl.col("a").list.to_struct())