diff --git a/crates/polars-arrow/src/array/dictionary/mod.rs b/crates/polars-arrow/src/array/dictionary/mod.rs
index 6947c9e071c7..65e7343762ed 100644
--- a/crates/polars-arrow/src/array/dictionary/mod.rs
+++ b/crates/polars-arrow/src/array/dictionary/mod.rs
@@ -25,7 +25,9 @@ use polars_error::{polars_bail, PolarsResult};
 use super::primitive::PrimitiveArray;
 use super::specification::check_indexes;
 use super::{new_empty_array, new_null_array, Array};
-use crate::array::dictionary::typed_iterator::{DictValue, DictionaryValuesIterTyped};
+use crate::array::dictionary::typed_iterator::{
+    DictValue, DictionaryIterTyped, DictionaryValuesIterTyped,
+};
 
 /// Trait denoting [`NativeType`]s that can be used as keys of a dictionary.
 /// # Safety
@@ -241,30 +243,22 @@ impl<K: DictionaryKey> DictionaryArray<K> {
     ///
     /// # Panics
     ///
-    /// Panics if the keys of this [`DictionaryArray`] have any null types.
-    /// If they do [`DictionaryArray::iter_typed`] should be called
+    /// Panics if the keys of this [`DictionaryArray`] have any nulls.
+    /// If they do, [`DictionaryArray::iter_typed`] should be used.
     pub fn values_iter_typed<V: DictValue>(&self) -> PolarsResult<DictionaryValuesIterTyped<K, V>> {
         let keys = &self.keys;
         assert_eq!(keys.null_count(), 0);
         let values = self.values.as_ref();
         let values = V::downcast_values(values)?;
-        Ok(unsafe { DictionaryValuesIterTyped::new(keys, values) })
+        Ok(DictionaryValuesIterTyped::new(keys, values))
     }
 
     /// Returns an iterator over the optional values of [`Option<V::IterValue>`].
-    ///
-    /// # Panics
-    ///
-    /// This function panics if the `values` array
-    pub fn iter_typed<V: DictValue>(
-        &self,
-    ) -> PolarsResult<ZipValidity<V::IterValue<'_>, DictionaryValuesIterTyped<K, V>, BitmapIter>>
-    {
+    pub fn iter_typed<V: DictValue>(&self) -> PolarsResult<DictionaryIterTyped<K, V>> {
         let keys = &self.keys;
         let values = self.values.as_ref();
         let values = V::downcast_values(values)?;
-        let values_iter = unsafe { DictionaryValuesIterTyped::new(keys, values) };
-        Ok(ZipValidity::new_with_validity(values_iter, self.validity()))
+        Ok(DictionaryIterTyped::new(keys, values))
     }
 
     /// Returns the [`ArrowDataType`] of this [`DictionaryArray`]
diff --git a/crates/polars-arrow/src/array/dictionary/typed_iterator.rs b/crates/polars-arrow/src/array/dictionary/typed_iterator.rs
index 6a543968b98d..87fb0e95bfbd 100644
--- a/crates/polars-arrow/src/array/dictionary/typed_iterator.rs
+++ b/crates/polars-arrow/src/array/dictionary/typed_iterator.rs
@@ -1,7 +1,7 @@
 use polars_error::{polars_err, PolarsResult};
 
 use super::DictionaryKey;
-use crate::array::{Array, PrimitiveArray, Utf8Array, Utf8ViewArray};
+use crate::array::{Array, PrimitiveArray, StaticArray, Utf8Array, Utf8ViewArray};
 use crate::trusted_len::TrustedLen;
 use crate::types::Offset;
 
@@ -85,7 +85,8 @@ pub struct DictionaryValuesIterTyped<'a, K: DictionaryKey, V: DictValue> {
 }
 
 impl<'a, K: DictionaryKey, V: DictValue> DictionaryValuesIterTyped<'a, K, V> {
-    pub(super) unsafe fn new(keys: &'a PrimitiveArray<K>, values: &'a V) -> Self {
+    pub(super) fn new(keys: &'a PrimitiveArray<K>, values: &'a V) -> Self {
+        assert_eq!(keys.null_count(), 0);
         Self {
             keys,
             values,
@@ -137,3 +138,68 @@ impl<'a, K: DictionaryKey, V: DictValue> DoubleEndedIterator
         }
     }
 }
+
+pub struct DictionaryIterTyped<'a, K: DictionaryKey, V: DictValue> {
+    keys: &'a PrimitiveArray<K>,
+    values: &'a V,
+    index: usize,
+    end: usize,
+}
+
+impl<'a, K: DictionaryKey, V: DictValue> DictionaryIterTyped<'a, K, V> {
+    pub(super) fn new(keys: &'a PrimitiveArray<K>, values: &'a V) -> Self {
+        Self {
+            keys,
+            values,
+            index: 0,
+            end: keys.len(),
+        }
+    }
+}
+
+impl<'a, K: DictionaryKey, V: DictValue> Iterator for DictionaryIterTyped<'a, K, V> {
+    type Item = Option<V::IterValue<'a>>;
+
+    #[inline]
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.index == self.end {
+            return None;
+        }
+        let old = self.index;
+        self.index += 1;
+        unsafe {
+            if let Some(key) = self.keys.get_unchecked(old) {
+                let idx = key.as_usize();
+                Some(Some(self.values.get_unchecked(idx)))
+            } else {
+                Some(None)
+            }
+        }
+    }
+
+    #[inline]
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        (self.end - self.index, Some(self.end - self.index))
+    }
+}
+
+unsafe impl<'a, K: DictionaryKey, V: DictValue> TrustedLen for DictionaryIterTyped<'a, K, V> {}
+
+impl<'a, K: DictionaryKey, V: DictValue> DoubleEndedIterator for DictionaryIterTyped<'a, K, V> {
+    #[inline]
+    fn next_back(&mut self) -> Option<Self::Item> {
+        if self.index == self.end {
+            None
+        } else {
+            self.end -= 1;
+            unsafe {
+                if let Some(key) = self.keys.get_unchecked(self.end) {
+                    let idx = key.as_usize();
+                    Some(Some(self.values.get_unchecked(idx)))
+                } else {
+                    Some(None)
+                }
+            }
+        }
+    }
+}
diff --git a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs
index 959c1f5ec666..019e9a80a962 100644
--- a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs
+++ b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs
@@ -279,6 +279,10 @@ impl CategoricalChunked {
         self
     }
 
+    pub fn _with_fast_unique(self, toggle: bool) -> Self {
+        self.with_fast_unique(toggle)
+    }
+
     /// Get a reference to the mapping of categorical types to the string values.
     pub fn get_rev_map(&self) -> &Arc<RevMapping> {
         if let DataType::Categorical(Some(rev_map), _) | DataType::Enum(Some(rev_map), _) =
diff --git a/crates/polars-core/src/utils/supertype.rs b/crates/polars-core/src/utils/supertype.rs
index 7181c89d4885..47ecd9c13a4f 100644
--- a/crates/polars-core/src/utils/supertype.rs
+++ b/crates/polars-core/src/utils/supertype.rs
@@ -264,28 +264,43 @@ pub fn get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
         },
         (dt, Unknown(kind)) => {
             match kind {
+                // numeric vs float|str -> always float|str
                 UnknownKind::Float | UnknownKind::Int(_) if dt.is_float() | dt.is_string() => Some(dt.clone()),
-                UnknownKind::Float if dt.is_numeric() => Some(Unknown(UnknownKind::Float)),
+                UnknownKind::Float if dt.is_integer() => Some(Unknown(UnknownKind::Float)),
+                // Materialize float
+                UnknownKind::Float if dt.is_float() => Some(dt.clone()),
+                // Materialize str
                 UnknownKind::Str if dt.is_string() | dt.is_enum() => Some(dt.clone()),
+                // Materialize str
                 #[cfg(feature = "dtype-categorical")]
                 UnknownKind::Str if dt.is_categorical() => {
                     let Categorical(_, ord) = dt else { unreachable!()};
                     Some(Categorical(None, *ord))
                 },
+                // Keep unknown
                 dynam if dt.is_null() => Some(Unknown(*dynam)),
+                // Find integer sizes
                 UnknownKind::Int(v) if dt.is_numeric() => {
-                    let smallest_fitting_dtype = if dt.is_unsigned_integer() && v.is_positive() {
-                        materialize_dyn_int_pos(*v).dtype()
-                    } else {
-                        materialize_smallest_dyn_int(*v).dtype()
-                    };
-                    match dt {
-                        UInt64 if smallest_fitting_dtype.is_signed_integer() => {
-                            // Ensure we don't cast to float when dealing with dynamic literals
-                            Some(Int64)
-                        },
-                        _ => {
-                            get_supertype(dt, &smallest_fitting_dtype)
+                    // Both dyn int
+                    if let Unknown(UnknownKind::Int(v_other)) = dt {
+                        // Take the maximum value to ensure we bubble up the required minimal size.
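+                        // e.g. the supertype of dyn ints 3 and 300 is dyn int 300,
+                        // which can later materialize as Int16.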
+                        Some(Unknown(UnknownKind::Int(std::cmp::max(*v, *v_other))))
+                    }
+                    // dyn int vs number
+                    else {
+                        let smallest_fitting_dtype = if dt.is_unsigned_integer() && v.is_positive() {
+                            materialize_dyn_int_pos(*v).dtype()
+                        } else {
+                            materialize_smallest_dyn_int(*v).dtype()
+                        };
+                        match dt {
+                            UInt64 if smallest_fitting_dtype.is_signed_integer() => {
+                                // Ensure we don't cast to float when dealing with dynamic literals
+                                Some(Int64)
+                            },
+                            _ => {
+                                get_supertype(dt, &smallest_fitting_dtype)
+                            }
+                        }
+                    }
                 }
diff --git a/crates/polars-io/src/csv/read/read_impl/batched_read.rs b/crates/polars-io/src/csv/read/read_impl/batched_read.rs
index 9098d255c6a2..64e165844e7a 100644
--- a/crates/polars-io/src/csv/read/read_impl/batched_read.rs
+++ b/crates/polars-io/src/csv/read/read_impl/batched_read.rs
@@ -7,6 +7,7 @@ use polars_core::frame::DataFrame;
 use polars_core::schema::SchemaRef;
 use polars_core::POOL;
 use polars_error::PolarsResult;
+use polars_utils::sync::SyncPtr;
 use polars_utils::IdxSize;
 use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
 
@@ -54,6 +55,8 @@ pub(crate) fn get_offsets(
     }
 }
 
+/// Reads bytes from `file` into `buf` and returns pointers into `buf` that can be parsed.
+/// TODO! this can be implemented without copying by pointing into the memory-mapped file.
 struct ChunkReader<'a> {
     file: &'a File,
     buf: Vec<u8>,
@@ -109,18 +112,23 @@ impl<'a> ChunkReader<'a> {
         self.buf_end = 0;
     }
 
-    fn return_slice(&self, start: usize, end: usize) -> (usize, usize) {
+    fn return_slice(&self, start: usize, end: usize) -> (SyncPtr<u8>, usize) {
         let slice = &self.buf[start..end];
         let len = slice.len();
-        (slice.as_ptr() as usize, len)
+        (slice.as_ptr().into(), len)
     }
 
-    fn get_buf(&self) -> (usize, usize) {
+    fn get_buf_remaining(&self) -> (SyncPtr<u8>, usize) {
         let slice = &self.buf[self.buf_end..];
         let len = slice.len();
-        (slice.as_ptr() as usize, len)
+        (slice.as_ptr().into(), len)
     }
 
+    // Get the next `n` offset positions, where `n` is the number of chunks.
+
+    // This returns pointers to slices into `buf`;
+    // we must process those slices before the next call,
+    // as that call will overwrite them.
     fn read(&mut self, n: usize) -> bool {
         self.reslice();
 
@@ -267,7 +275,7 @@ pub struct BatchedCsvReaderRead<'a> {
     chunk_size: usize,
     finished: bool,
     file_chunk_reader: ChunkReader<'a>,
-    file_chunks: Vec<(usize, usize)>,
+    file_chunks: Vec<(SyncPtr<u8>, usize)>,
     projection: Vec<usize>,
     starting_point_offset: Option<usize>,
     row_index: Option<RowIndex>,
@@ -292,6 +300,7 @@ pub struct BatchedCsvReaderRead<'a> {
 }
 //
 impl<'a> BatchedCsvReaderRead<'a> {
+    /// `n` is the number of batches.
     pub fn next_batches(&mut self, n: usize) -> PolarsResult<Option<Vec<DataFrame>>> {
         if n == 0 || self.finished {
             return Ok(None);
@@ -320,7 +329,8 @@ impl<'a> BatchedCsvReaderRead<'a> {
         // ensure we process the final slice as well.
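         // (the remainder is fetched via `get_buf_remaining`, i.e. everything in `buf` past `buf_end`)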
if self.file_chunk_reader.finished && self.file_chunks.len() < n { // get the final slice - self.file_chunks.push(self.file_chunk_reader.get_buf()); + self.file_chunks + .push(self.file_chunk_reader.get_buf_remaining()); self.finished = true } @@ -333,7 +343,7 @@ impl<'a> BatchedCsvReaderRead<'a> { self.file_chunks .par_iter() .map(|(ptr, len)| { - let chunk = unsafe { std::slice::from_raw_parts(*ptr as *const u8, *len) }; + let chunk = unsafe { std::slice::from_raw_parts(ptr.get(), *len) }; let stop_at_n_bytes = chunk.len(); let mut df = read_chunk( chunk, diff --git a/crates/polars-lazy/src/dsl/functions.rs b/crates/polars-lazy/src/dsl/functions.rs index a08559a9d14d..b488df382fbc 100644 --- a/crates/polars-lazy/src/dsl/functions.rs +++ b/crates/polars-lazy/src/dsl/functions.rs @@ -4,115 +4,39 @@ //! use polars_core::prelude::*; pub use polars_plan::dsl::functions::*; +use polars_plan::prelude::UnionArgs; use rayon::prelude::*; use crate::prelude::*; pub(crate) fn concat_impl>( inputs: L, - rechunk: bool, - parallel: bool, - from_partitioned_ds: bool, - convert_supertypes: bool, + args: UnionArgs, ) -> PolarsResult { let mut inputs = inputs.as_ref().to_vec(); - let mut lf = std::mem::take( + let lf = std::mem::take( inputs .get_mut(0) .ok_or_else(|| polars_err!(NoData: "empty container given"))?, ); let mut opt_state = lf.opt_state; - let options = UnionOptions { - parallel, - from_partitioned_ds, - rechunk, - ..Default::default() - }; - - let lf = match &mut lf.logical_plan { - // reuse the same union - DslPlan::Union { - inputs: existing_inputs, - options: opts, - } if opts == &options => { - for lf in &mut inputs[1..] { - // ensure we enable file caching if any lf has it enabled - opt_state.file_caching |= lf.opt_state.file_caching; - let lp = std::mem::take(&mut lf.logical_plan); - existing_inputs.push(lp) - } - lf - }, - _ => { - let mut lps = Vec::with_capacity(inputs.len()); - lps.push(lf.logical_plan); - - for lf in &mut inputs[1..] { - // ensure we enable file caching if any lf has it enabled - opt_state.file_caching |= lf.opt_state.file_caching; - let lp = std::mem::take(&mut lf.logical_plan); - lps.push(lp) - } - - let lp = DslPlan::Union { - inputs: lps, - options, - }; - let mut lf = LazyFrame::from(lp); - lf.opt_state = opt_state; - - lf - }, - }; - if convert_supertypes { - let DslPlan::Union { - mut inputs, - options, - } = lf.logical_plan - else { - unreachable!() - }; - let mut schema = inputs[0].compute_schema()?.as_ref().clone(); - - let mut changed = false; - for input in inputs[1..].iter() { - changed |= schema.to_supertype(input.compute_schema()?.as_ref())?; - } - - let mut placeholder = DslPlan::default(); - if changed { - let mut exprs = vec![]; - for input in &mut inputs { - std::mem::swap(input, &mut placeholder); - let input_schema = placeholder.compute_schema()?; - - exprs.clear(); - let to_cast = input_schema.iter().zip(schema.iter_dtypes()).flat_map( - |((left_name, left_type), st)| { - if left_type != st { - Some(col(left_name.as_ref()).cast(st.clone())) - } else { - None - } - }, - ); - exprs.extend(to_cast); - let mut lf = LazyFrame::from(placeholder); - if !exprs.is_empty() { - lf = lf.with_columns(exprs.as_slice()); - } + let mut lps = Vec::with_capacity(inputs.len()); + lps.push(lf.logical_plan); - placeholder = lf.logical_plan; - std::mem::swap(&mut placeholder, input); - } - } - Ok(LazyFrame::from(DslPlan::Union { inputs, options })) - } else { - Ok(lf) + for lf in &mut inputs[1..] 
{ + // ensure we enable file caching if any lf has it enabled + opt_state.file_caching |= lf.opt_state.file_caching; + let lp = std::mem::take(&mut lf.logical_plan); + lps.push(lp) } + + let lp = DslPlan::Union { inputs: lps, args }; + let mut lf = LazyFrame::from(lp); + lf.opt_state = opt_state; + Ok(lf) } #[cfg(feature = "diagonal_concat")] @@ -216,32 +140,9 @@ pub fn concat_lf_horizontal>( Ok(lf) } -#[derive(Clone, Copy)] -pub struct UnionArgs { - pub parallel: bool, - pub rechunk: bool, - pub to_supertypes: bool, -} - -impl Default for UnionArgs { - fn default() -> Self { - Self { - parallel: true, - rechunk: true, - to_supertypes: false, - } - } -} - /// Concat multiple [`LazyFrame`]s vertically. pub fn concat>(inputs: L, args: UnionArgs) -> PolarsResult { - concat_impl( - inputs, - args.rechunk, - args.parallel, - false, - args.to_supertypes, - ) + concat_impl(inputs, args) } /// Collect all [`LazyFrame`] computations. diff --git a/crates/polars-lazy/src/prelude.rs b/crates/polars-lazy/src/prelude.rs index bf672ccfd755..b986b5924d1b 100644 --- a/crates/polars-lazy/src/prelude.rs +++ b/crates/polars-lazy/src/prelude.rs @@ -13,6 +13,7 @@ pub use polars_plan::logical_plan::{ AnonymousScan, AnonymousScanArgs, AnonymousScanOptions, DslPlan, Literal, LiteralValue, Null, NULL, }; +pub use polars_plan::prelude::UnionArgs; pub(crate) use polars_plan::prelude::*; #[cfg(feature = "rolling_window")] pub use polars_time::{prelude::RollingOptions, Duration}; diff --git a/crates/polars-lazy/src/scan/csv.rs b/crates/polars-lazy/src/scan/csv.rs index 9999d5219ce5..86373dbfa0e3 100644 --- a/crates/polars-lazy/src/scan/csv.rs +++ b/crates/polars-lazy/src/scan/csv.rs @@ -387,6 +387,13 @@ impl LazyFileListReader for LazyCsvReader { fn concat_impl(&self, lfs: Vec) -> PolarsResult { // set to false, as the csv parser has full thread utilization - concat_impl(&lfs, self.rechunk(), false, true, false) + let args = UnionArgs { + rechunk: self.rechunk(), + parallel: false, + to_supertypes: false, + from_partitioned_ds: true, + ..Default::default() + }; + concat_impl(&lfs, args) } } diff --git a/crates/polars-lazy/src/scan/file_list_reader.rs b/crates/polars-lazy/src/scan/file_list_reader.rs index bc19aea8a7d5..70971c424f12 100644 --- a/crates/polars-lazy/src/scan/file_list_reader.rs +++ b/crates/polars-lazy/src/scan/file_list_reader.rs @@ -5,6 +5,7 @@ use polars_core::prelude::*; use polars_io::cloud::CloudOptions; use polars_io::utils::is_cloud_url; use polars_io::RowIndex; +use polars_plan::prelude::UnionArgs; use crate::prelude::*; @@ -83,7 +84,14 @@ pub trait LazyFileListReader: Clone { /// This method should not take into consideration [LazyFileListReader::n_rows] /// nor [LazyFileListReader::row_index]. fn concat_impl(&self, lfs: Vec) -> PolarsResult { - concat_impl(&lfs, self.rechunk(), true, true, false) + let args = UnionArgs { + rechunk: self.rechunk(), + parallel: true, + to_supertypes: false, + from_partitioned_ds: true, + ..Default::default() + }; + concat_impl(&lfs, args) } /// Get the final [LazyFrame]. diff --git a/crates/polars-ops/src/series/ops/cut.rs b/crates/polars-ops/src/series/ops/cut.rs index f1721ac2fd45..aa1025f1ed55 100644 --- a/crates/polars-ops/src/series/ops/cut.rs +++ b/crates/polars-ops/src/series/ops/cut.rs @@ -25,6 +25,9 @@ fn map_cats( PartialOrd::gt }; + // Ensure fast unique is only set if all labels were seen. 
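+    // (`sorted_breaks` holds the cut points, so there are `sorted_breaks.len() + 1` buckets, one per label.)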
+ let mut label_has_value = vec![false; 1 + sorted_breaks.len()]; + if include_breaks { // This is to replicate the behavior of the old buggy version that only worked on series and // returned a dataframe. That included a column of the right endpoint of the interval. So we @@ -33,8 +36,11 @@ fn map_cats( let mut brk_vals = PrimitiveChunkedBuilder::::new("brk", s.len()); s_iter .map(|opt| { - opt.filter(|x| !x.is_nan()) - .map(|x| sorted_breaks.partition_point(|v| op(&x, v))) + opt.filter(|x| !x.is_nan()).map(|x| { + let pt = sorted_breaks.partition_point(|v| op(&x, v)); + unsafe { *label_has_value.get_unchecked_mut(pt) = true }; + pt + }) }) .for_each(|idx| match idx { None => { @@ -47,17 +53,23 @@ fn map_cats( }, }); - let outvals = vec![brk_vals.finish().into_series(), bld.finish().into_series()]; + let outvals = vec![ + brk_vals.finish().into_series(), + bld.finish() + ._with_fast_unique(label_has_value.iter().all(bool::clone)) + .into_series(), + ]; Ok(StructChunked::new(&out_name, &outvals)?.into_series()) } else { Ok(bld .drain_iter_and_finish(s_iter.map(|opt| { - opt.filter(|x| !x.is_nan()).map(|x| unsafe { - labels - .get_unchecked(sorted_breaks.partition_point(|v| op(&x, v))) - .as_str() + opt.filter(|x| !x.is_nan()).map(|x| { + let pt = sorted_breaks.partition_point(|v| op(&x, v)); + unsafe { *label_has_value.get_unchecked_mut(pt) = true }; + unsafe { labels.get_unchecked(pt).as_str() } }) })) + ._with_fast_unique(label_has_value.iter().all(bool::clone)) .into_series()) } } @@ -145,3 +157,31 @@ pub fn qcut( map_cats(&s, &cut_labels, &qbreaks, left_closed, include_breaks) } + +mod test { + #[test] + fn test_map_cats_fast_unique() { + // This test is here to check the fast unique flag is set when it can be + // as it is not visible to Python. 
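+        // With breaks [2.0, 4.0] the values 1..=5 fill all three buckets, so the flag must be set.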
+ use polars_core::prelude::*; + + use super::map_cats; + + let s = Series::new("x", &[1, 2, 3, 4, 5]); + + let labels = &["a", "b", "c"].map(str::to_owned); + let breaks = &[2.0, 4.0]; + let left_closed = false; + + let include_breaks = false; + let out = map_cats(&s, labels, breaks, left_closed, include_breaks).unwrap(); + let out = out.categorical().unwrap(); + assert!(out._can_fast_unique()); + + let include_breaks = true; + let out = map_cats(&s, labels, breaks, left_closed, include_breaks).unwrap(); + let out = out.struct_().unwrap().fields()[1].clone(); + let out = out.categorical().unwrap(); + assert!(out._can_fast_unique()); + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/rolling.rs b/crates/polars-plan/src/dsl/function_expr/rolling.rs index d3bf85877cd8..f1ae64c5f792 100644 --- a/crates/polars-plan/src/dsl/function_expr/rolling.rs +++ b/crates/polars-plan/src/dsl/function_expr/rolling.rs @@ -75,10 +75,6 @@ fn convert<'a>( let mut by = ss[1].clone(); by = by.rechunk(); - polars_ensure!( - options.weights.is_none(), - ComputeError: "`weights` is not supported in 'rolling by' expression" - ); let (by, tz) = match by.dtype() { DataType::Datetime(tu, tz) => (by.cast(&DataType::Datetime(*tu, None))?, tz), DataType::Date => ( @@ -116,12 +112,12 @@ fn convert<'a>( let options = RollingOptionsImpl { window_size: options.window_size, min_periods: options.min_periods, - weights: None, + weights: options.weights, center: options.center, by: Some(by_values), tu: Some(tu), tz: tz.as_ref(), - closed_window: options.closed_window.or(Some(ClosedWindow::Right)), + closed_window: options.closed_window, fn_params: options.fn_params.clone(), }; @@ -130,7 +126,7 @@ fn convert<'a>( } pub(super) fn rolling_min(s: &Series, options: RollingOptions) -> PolarsResult { - s.rolling_min(options.clone().try_into()?) + s.rolling_min(options.into()) } pub(super) fn rolling_min_by(s: &[Series], options: RollingOptions) -> PolarsResult { @@ -138,7 +134,7 @@ pub(super) fn rolling_min_by(s: &[Series], options: RollingOptions) -> PolarsRes } pub(super) fn rolling_max(s: &Series, options: RollingOptions) -> PolarsResult { - s.rolling_max(options.clone().try_into()?) + s.rolling_max(options.into()) } pub(super) fn rolling_max_by(s: &[Series], options: RollingOptions) -> PolarsResult { @@ -146,7 +142,7 @@ pub(super) fn rolling_max_by(s: &[Series], options: RollingOptions) -> PolarsRes } pub(super) fn rolling_mean(s: &Series, options: RollingOptions) -> PolarsResult { - s.rolling_mean(options.clone().try_into()?) + s.rolling_mean(options.into()) } pub(super) fn rolling_mean_by(s: &[Series], options: RollingOptions) -> PolarsResult { @@ -154,7 +150,7 @@ pub(super) fn rolling_mean_by(s: &[Series], options: RollingOptions) -> PolarsRe } pub(super) fn rolling_sum(s: &Series, options: RollingOptions) -> PolarsResult { - s.rolling_sum(options.clone().try_into()?) + s.rolling_sum(options.into()) } pub(super) fn rolling_sum_by(s: &[Series], options: RollingOptions) -> PolarsResult { @@ -162,7 +158,7 @@ pub(super) fn rolling_sum_by(s: &[Series], options: RollingOptions) -> PolarsRes } pub(super) fn rolling_quantile(s: &Series, options: RollingOptions) -> PolarsResult { - s.rolling_quantile(options.clone().try_into()?) 
+ s.rolling_quantile(options.into()) } pub(super) fn rolling_quantile_by(s: &[Series], options: RollingOptions) -> PolarsResult { @@ -174,7 +170,7 @@ pub(super) fn rolling_quantile_by(s: &[Series], options: RollingOptions) -> Pola } pub(super) fn rolling_var(s: &Series, options: RollingOptions) -> PolarsResult { - s.rolling_var(options.clone().try_into()?) + s.rolling_var(options.into()) } pub(super) fn rolling_var_by(s: &[Series], options: RollingOptions) -> PolarsResult { @@ -182,7 +178,7 @@ pub(super) fn rolling_var_by(s: &[Series], options: RollingOptions) -> PolarsRes } pub(super) fn rolling_std(s: &Series, options: RollingOptions) -> PolarsResult { - s.rolling_std(options.clone().try_into()?) + s.rolling_std(options.into()) } pub(super) fn rolling_std_by(s: &[Series], options: RollingOptions) -> PolarsResult { diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 235fccf905d7..ac3439d20e3b 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -1254,9 +1254,6 @@ impl Expr { false, ) } else { - if !options.window_size.parsed_int { - panic!("if dynamic windows are used in a rolling aggregation, the 'by' argument must be set") - } self.apply_private(FunctionExpr::RollingExpr(rolling_function(options))) } } diff --git a/crates/polars-plan/src/logical_plan/conversion/convert_utils.rs b/crates/polars-plan/src/logical_plan/conversion/convert_utils.rs new file mode 100644 index 000000000000..db7c591d16c6 --- /dev/null +++ b/crates/polars-plan/src/logical_plan/conversion/convert_utils.rs @@ -0,0 +1,44 @@ +use super::*; + +pub(super) fn convert_st_union( + inputs: &mut [Node], + lp_arena: &mut Arena, + expr_arena: &mut Arena, +) -> PolarsResult<()> { + let mut schema = (**lp_arena.get(inputs[0]).schema(lp_arena)).clone(); + + let mut changed = false; + for input in inputs[1..].iter() { + let schema_other = lp_arena.get(*input).schema(lp_arena); + changed |= schema.to_supertype(schema_other.as_ref())?; + } + + if changed { + for input in inputs { + let mut exprs = vec![]; + let input_schema = lp_arena.get(*input).schema(lp_arena); + + let to_cast = input_schema.iter().zip(schema.iter_dtypes()).flat_map( + |((left_name, left_type), st)| { + if left_type != st { + Some(col(left_name.as_ref()).cast(st.clone())) + } else { + None + } + }, + ); + exprs.extend(to_cast); + + if !exprs.is_empty() { + let expr = to_expr_irs(exprs, expr_arena); + let lp = IRBuilder::new(*input, expr_arena, lp_arena) + .with_columns(expr, Default::default()) + .build(); + + let node = lp_arena.add(lp); + *input = node + } + } + } + Ok(()) +} diff --git a/crates/polars-plan/src/logical_plan/conversion/dsl_plan_to_ir_plan.rs b/crates/polars-plan/src/logical_plan/conversion/dsl_to_ir.rs similarity index 98% rename from crates/polars-plan/src/logical_plan/conversion/dsl_plan_to_ir_plan.rs rename to crates/polars-plan/src/logical_plan/conversion/dsl_to_ir.rs index c0b2f3f3f571..2d0aa4b884b8 100644 --- a/crates/polars-plan/src/logical_plan/conversion/dsl_plan_to_ir_plan.rs +++ b/crates/polars-plan/src/logical_plan/conversion/dsl_to_ir.rs @@ -146,12 +146,18 @@ pub fn to_alp_impl( options, predicate: None, }, - DslPlan::Union { inputs, options } => { - let inputs = inputs + DslPlan::Union { inputs, args } => { + let mut inputs = inputs .into_iter() .map(|lp| to_alp_impl(lp, expr_arena, lp_arena, convert)) - .collect::>() + .collect::>>() .map_err(|e| e.context(failed_input!(vertical concat)))?; + + if args.to_supertypes { + convert_utils::convert_st_union(&mut 
inputs, lp_arena, expr_arena) + .map_err(|e| e.context(failed_input!(vertical concat)))?; + } + let options = args.into(); IR::Union { inputs, options } }, DslPlan::HConcat { diff --git a/crates/polars-plan/src/logical_plan/conversion/expr_to_expr_ir.rs b/crates/polars-plan/src/logical_plan/conversion/expr_to_ir.rs similarity index 100% rename from crates/polars-plan/src/logical_plan/conversion/expr_to_expr_ir.rs rename to crates/polars-plan/src/logical_plan/conversion/expr_to_ir.rs diff --git a/crates/polars-plan/src/logical_plan/conversion/mod.rs b/crates/polars-plan/src/logical_plan/conversion/mod.rs index 0c451394be4d..f87a38964149 100644 --- a/crates/polars-plan/src/logical_plan/conversion/mod.rs +++ b/crates/polars-plan/src/logical_plan/conversion/mod.rs @@ -1,5 +1,6 @@ -mod dsl_plan_to_ir_plan; -mod expr_to_expr_ir; +mod convert_utils; +mod dsl_to_ir; +mod expr_to_ir; mod ir_to_dsl; #[cfg(any(feature = "ipc", feature = "parquet", feature = "csv"))] mod scans; @@ -7,8 +8,8 @@ mod stack_opt; use std::borrow::Cow; -pub use dsl_plan_to_ir_plan::*; -pub use expr_to_expr_ir::*; +pub use dsl_to_ir::*; +pub use expr_to_ir::*; pub use ir_to_dsl::*; use polars_core::prelude::*; use polars_utils::vec::ConvertVec; @@ -53,12 +54,15 @@ impl IR { }, #[cfg(feature = "python")] IR::PythonScan { options, .. } => DslPlan::PythonScan { options }, - IR::Union { inputs, options } => { + IR::Union { inputs, .. } => { let inputs = inputs .into_iter() .map(|node| convert_to_lp(node, lp_arena)) .collect(); - DslPlan::Union { inputs, options } + DslPlan::Union { + inputs, + args: Default::default(), + } }, IR::HConcat { inputs, diff --git a/crates/polars-plan/src/logical_plan/format.rs b/crates/polars-plan/src/logical_plan/format.rs index 7b0930a3b8e9..70c06d095300 100644 --- a/crates/polars-plan/src/logical_plan/format.rs +++ b/crates/polars-plan/src/logical_plan/format.rs @@ -1,6 +1,6 @@ use std::borrow::Cow; use std::fmt; -use std::fmt::{Debug, Display, Formatter, Write}; +use std::fmt::{Debug, Display, Formatter}; use std::path::PathBuf; use polars_core::prelude::AnyValue; @@ -81,14 +81,16 @@ impl DslPlan { options.n_rows, ) }, - Union { inputs, options } => { - let mut name = String::new(); - let name = if let Some(slice) = options.slice { - write!(name, "SLICED UNION: {slice:?}")?; - name.as_str() - } else { - "UNION" - }; + Union { inputs, .. } => { + // let mut name = String::new(); + // THIS is commented out, but must be restored once we format IR's + // let name = if let Some(slice) = options.slice { + // write!(name, "SLICED UNION: {slice:?}")?; + // name.as_str() + // } else { + // "UNION" + // }; + let name = "UNION"; // 3 levels of indentation // - 0 => UNION ... END UNION // - 1 => PLAN 0, PLAN 1, ... 
PLAN N diff --git a/crates/polars-plan/src/logical_plan/mod.rs b/crates/polars-plan/src/logical_plan/mod.rs index f89ff0ab1eab..749ff33e5006 100644 --- a/crates/polars-plan/src/logical_plan/mod.rs +++ b/crates/polars-plan/src/logical_plan/mod.rs @@ -153,9 +153,10 @@ pub enum DslPlan { input: Arc, function: DslFunction, }, + /// Vertical concatenation Union { inputs: Vec, - options: UnionOptions, + args: UnionArgs, }, /// Horizontal concatenation of multiple plans HConcat { @@ -196,7 +197,7 @@ impl Clone for DslPlan { Self::Sort {input,by_column, slice, sort_options } => Self::Sort { input: input.clone(), by_column: by_column.clone(), slice: slice.clone(), sort_options: sort_options.clone() }, Self::Slice { input, offset, len } => Self::Slice { input: input.clone(), offset: offset.clone(), len: len.clone() }, Self::MapFunction { input, function } => Self::MapFunction { input: input.clone(), function: function.clone() }, - Self::Union { inputs, options } => Self::Union { inputs: inputs.clone(), options: options.clone() }, + Self::Union { inputs, args} => Self::Union { inputs: inputs.clone(), args: args.clone() }, Self::HConcat { inputs, schema, options } => Self::HConcat { inputs: inputs.clone(), schema: schema.clone(), options: options.clone() }, Self::ExtContext { input, contexts, } => Self::ExtContext { input: input.clone(), contexts: contexts.clone() }, Self::Sink { input, payload } => Self::Sink { input: input.clone(), payload: payload.clone() }, diff --git a/crates/polars-plan/src/logical_plan/optimizer/cache_states.rs b/crates/polars-plan/src/logical_plan/optimizer/cache_states.rs index bcc5c243c672..353807ee9095 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/cache_states.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/cache_states.rs @@ -7,6 +7,7 @@ fn get_upper_projections( lp_arena: &Arena, expr_arena: &Arena, names_scratch: &mut Vec, + found_required_columns: &mut bool, ) -> bool { let parent = lp_arena.get(parent); @@ -16,6 +17,7 @@ fn get_upper_projections( SimpleProjection { columns, .. } => { let iter = columns.iter_names().map(|s| ColumnName::from(s.as_str())); names_scratch.extend(iter); + *found_required_columns = true; false }, Filter { predicate, .. } => { @@ -201,7 +203,7 @@ pub(super) fn set_cache_states( v.parents.push(frame.parent); v.cache_nodes.push(frame.current); - let mut found_columns = false; + let mut found_required_columns = false; for parent_node in frame.parent.into_iter().flatten() { let keep_going = get_upper_projections( @@ -209,9 +211,9 @@ pub(super) fn set_cache_states( lp_arena, expr_arena, &mut names_scratch, + &mut found_required_columns, ); if !names_scratch.is_empty() { - found_columns = true; v.names_union.extend(names_scratch.drain(..)); } // We stop early as we want to find the first projection node above the cache. @@ -241,7 +243,7 @@ pub(super) fn set_cache_states( // There was no explicit projection and we must take // all columns - if !found_columns { + if !found_required_columns { let schema = lp.schema(lp_arena); v.names_union.extend( schema diff --git a/crates/polars-plan/src/logical_plan/options.rs b/crates/polars-plan/src/logical_plan/options.rs index 20f7fb1183ff..1b561f7bd5ad 100644 --- a/crates/polars-plan/src/logical_plan/options.rs +++ b/crates/polars-plan/src/logical_plan/options.rs @@ -274,3 +274,40 @@ impl Default for ProjectionOptions { } } } + +// Arguments given to `concat`. Differs from `UnionOptions` as the latter is IR state. 
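+// It is converted into `UnionOptions` during DSL -> IR lowering (see the `From` impl below).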
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+pub struct UnionArgs {
+    pub parallel: bool,
+    pub rechunk: bool,
+    pub to_supertypes: bool,
+    pub diagonal: bool,
+    // If it is a union from a scan over multiple files.
+    pub from_partitioned_ds: bool,
+}
+
+impl Default for UnionArgs {
+    fn default() -> Self {
+        Self {
+            parallel: true,
+            rechunk: true,
+            to_supertypes: false,
+            diagonal: false,
+            from_partitioned_ds: false,
+        }
+    }
+}
+
+impl From<UnionArgs> for UnionOptions {
+    fn from(args: UnionArgs) -> Self {
+        UnionOptions {
+            slice: None,
+            parallel: args.parallel,
+            rows: (None, 0),
+            from_partitioned_ds: args.from_partitioned_ds,
+            flattened_by_opt: false,
+            rechunk: args.rechunk,
+        }
+    }
+}
diff --git a/crates/polars-plan/src/logical_plan/schema.rs b/crates/polars-plan/src/logical_plan/schema.rs
index 7d7044e498e1..6c4629a80cb0 100644
--- a/crates/polars-plan/src/logical_plan/schema.rs
+++ b/crates/polars-plan/src/logical_plan/schema.rs
@@ -12,6 +12,10 @@ use super::hive::HivePartitions;
 use crate::prelude::*;
 
 impl DslPlan {
+    // Warning! This should not be used on the DSL internally.
+    // All schema resolving should be done during conversion to [`IR`].
+
+    /// Compute the schema. This requires conversion to [`IR`] and type-resolving.
     pub fn compute_schema(&self) -> PolarsResult<SchemaRef> {
         let opt_state = OptState {
             eager: true,
diff --git a/crates/polars-plan/src/logical_plan/tree_format.rs b/crates/polars-plan/src/logical_plan/tree_format.rs
index f64c4dfc3f61..5a227a5660db 100644
--- a/crates/polars-plan/src/logical_plan/tree_format.rs
+++ b/crates/polars-plan/src/logical_plan/tree_format.rs
@@ -163,14 +163,16 @@ impl<'a> TreeFmtNode<'a> {
                     vec![]
                 },
             ),
-            NL(h, Union { inputs, options }) => ND(
+            NL(h, Union { inputs, .. }) => ND(
                 wh(
                     h,
-                    &(if let Some(slice) = options.slice {
-                        format!("SLICED UNION: {slice:?}")
-                    } else {
-                        "UNION".to_string()
-                    }),
+                    // This is commented out, but must be restored when we convert to IRs.
+ // &(if let Some(slice) = options.slice { + // format!("SLICED UNION: {slice:?}") + // } else { + // "UNION".to_string() + // }), + "UNION", ), inputs .iter() diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index 22596b0f5cf9..a2655caf7342 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -13,8 +13,8 @@ use sqlparser::ast::ExactNumberInfo; use sqlparser::ast::{ ArrayAgg, ArrayElemTypeDef, BinaryOperator as SQLBinaryOperator, BinaryOperator, CastFormat, DataType as SQLDataType, DateTimeField, Expr as SQLExpr, Function as SQLFunction, Ident, - JoinConstraint, OrderByExpr, Query as Subquery, SelectItem, TimezoneInfo, TrimWhereField, - UnaryOperator, Value as SQLValue, + JoinConstraint, ObjectName, OrderByExpr, Query as Subquery, SelectItem, TimezoneInfo, + TrimWhereField, UnaryOperator, Value as SQLValue, }; use sqlparser::dialect::GenericDialect; use sqlparser::parser::{Parser, ParserOptions}; @@ -24,41 +24,53 @@ use crate::SQLContext; pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult { Ok(match data_type { + // --------------------------------- + // array/list + // --------------------------------- SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_type)) | SQLDataType::Array(ArrayElemTypeDef::SquareBracket(inner_type)) => { DataType::List(Box::new(map_sql_polars_datatype(inner_type)?)) }, - #[cfg(feature = "dtype-decimal")] - SQLDataType::Dec(info) | SQLDataType::Decimal(info) | SQLDataType::Numeric(info) => { - match *info { - ExactNumberInfo::PrecisionAndScale(p, s) => { - DataType::Decimal(Some(p as usize), Some(s as usize)) - }, - ExactNumberInfo::Precision(p) => DataType::Decimal(Some(p as usize), Some(0)), - ExactNumberInfo::None => DataType::Decimal(Some(38), Some(9)), - } - }, - SQLDataType::BigInt(_) => DataType::Int64, - SQLDataType::Boolean => DataType::Boolean, + + // --------------------------------- + // binary + // --------------------------------- SQLDataType::Bytea | SQLDataType::Bytes(_) | SQLDataType::Binary(_) | SQLDataType::Blob(_) | SQLDataType::Varbinary(_) => DataType::Binary, - SQLDataType::Char(_) - | SQLDataType::CharVarying(_) - | SQLDataType::Character(_) - | SQLDataType::CharacterVarying(_) - | SQLDataType::Clob(_) - | SQLDataType::String(_) - | SQLDataType::Text - | SQLDataType::Uuid - | SQLDataType::Varchar(_) => DataType::String, - SQLDataType::Date => DataType::Date, - SQLDataType::Double - | SQLDataType::DoublePrecision - | SQLDataType::Float8 - | SQLDataType::Float64 => DataType::Float64, + + // --------------------------------- + // boolean + // --------------------------------- + SQLDataType::Boolean | SQLDataType::Bool => DataType::Boolean, + + // --------------------------------- + // signed integer + // --------------------------------- + SQLDataType::Int(_) | SQLDataType::Integer(_) => DataType::Int32, + SQLDataType::Int2(_) | SQLDataType::SmallInt(_) => DataType::Int16, + SQLDataType::Int4(_) | SQLDataType::MediumInt(_) => DataType::Int32, + SQLDataType::Int8(_) | SQLDataType::BigInt(_) => DataType::Int64, + SQLDataType::TinyInt(_) => DataType::Int8, + + // --------------------------------- + // unsigned integer: the following do not map to PostgreSQL types/syntax, but + // are enabled for wider compatibility (eg: "CAST(col AS BIGINT UNSIGNED)"). 
+ // --------------------------------- + SQLDataType::UnsignedInt(_) | SQLDataType::UnsignedInteger(_) => DataType::UInt32, + SQLDataType::UnsignedInt2(_) | SQLDataType::UnsignedSmallInt(_) => DataType::UInt16, + SQLDataType::UnsignedInt4(_) | SQLDataType::UnsignedMediumInt(_) => DataType::UInt32, + SQLDataType::UnsignedInt8(_) | SQLDataType::UnsignedBigInt(_) => DataType::UInt64, + SQLDataType::UnsignedTinyInt(_) => DataType::UInt8, // see also: "custom" types below + + // --------------------------------- + // float + // --------------------------------- + SQLDataType::Double | SQLDataType::DoublePrecision | SQLDataType::Float8 => { + DataType::Float64 + }, SQLDataType::Float(n_bytes) => match n_bytes { Some(n) if (1u64..=24u64).contains(n) => DataType::Float32, Some(n) if (25u64..=53u64).contains(n) => DataType::Float64, @@ -68,12 +80,26 @@ pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult DataType::Float64, }, SQLDataType::Float4 | SQLDataType::Real => DataType::Float32, - SQLDataType::Int(_) | SQLDataType::Integer(_) => DataType::Int32, - SQLDataType::Int2(_) => DataType::Int16, - SQLDataType::Int4(_) => DataType::Int32, - SQLDataType::Int8(_) => DataType::Int64, + + // --------------------------------- + // decimal + // --------------------------------- + #[cfg(feature = "dtype-decimal")] + SQLDataType::Dec(info) | SQLDataType::Decimal(info) | SQLDataType::Numeric(info) => { + match *info { + ExactNumberInfo::PrecisionAndScale(p, s) => { + DataType::Decimal(Some(p as usize), Some(s as usize)) + }, + ExactNumberInfo::Precision(p) => DataType::Decimal(Some(p as usize), Some(0)), + ExactNumberInfo::None => DataType::Decimal(Some(38), Some(9)), + } + }, + + // --------------------------------- + // temporal + // --------------------------------- + SQLDataType::Date => DataType::Date, SQLDataType::Interval => DataType::Duration(TimeUnit::Microseconds), - SQLDataType::SmallInt(_) => DataType::Int16, SQLDataType::Time(_, tz) => match tz { TimezoneInfo::None => DataType::Time, _ => { @@ -97,16 +123,41 @@ pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult DataType::Int8, - SQLDataType::UnsignedBigInt(_) => DataType::UInt64, - SQLDataType::UnsignedInt(_) | SQLDataType::UnsignedInteger(_) => DataType::UInt32, - SQLDataType::UnsignedInt2(_) => DataType::UInt16, - SQLDataType::UnsignedInt4(_) => DataType::UInt32, - SQLDataType::UnsignedInt8(_) => DataType::UInt64, - SQLDataType::UnsignedSmallInt(_) => DataType::UInt16, - SQLDataType::UnsignedTinyInt(_) => DataType::UInt8, - _ => polars_bail!(ComputeError: "SQL datatype {:?} is not yet supported", data_type), + // --------------------------------- + // string + // --------------------------------- + SQLDataType::Char(_) + | SQLDataType::CharVarying(_) + | SQLDataType::Character(_) + | SQLDataType::CharacterVarying(_) + | SQLDataType::Clob(_) + | SQLDataType::String(_) + | SQLDataType::Text + | SQLDataType::Uuid + | SQLDataType::Varchar(_) => DataType::String, + + // --------------------------------- + // custom + // --------------------------------- + SQLDataType::Custom(ObjectName(idents), _) => match idents.as_slice() { + [Ident { value, .. }] => match value.to_lowercase().as_str() { + // these integer types are not supported by the PostgreSQL core distribution, + // but they ARE available via `pguint` (https://github.com/petere/pguint), an + // extension maintained by one of the PostgreSQL core developers. 
+ "uint1" => DataType::UInt8, + "uint2" => DataType::UInt16, + "uint4" | "uint" => DataType::UInt32, + "uint8" => DataType::UInt64, + // `pguint` also provides a 1 byte (8bit) integer type alias + "int1" => DataType::Int8, + _ => { + polars_bail!(ComputeError: "SQL datatype {:?} is not currently supported", value) + }, + }, + _ => polars_bail!(ComputeError: "SQL datatype {:?} is not currently supported", idents), + }, + _ => polars_bail!(ComputeError: "SQL datatype {:?} is not currently supported", data_type), }) } @@ -500,7 +551,7 @@ impl SQLExprVisitor<'_> { return Ok(expr.str().json_decode(None, None)); } let polars_type = map_sql_polars_datatype(data_type)?; - Ok(expr.cast(polars_type)) + Ok(expr.strict_cast(polars_type)) } /// Visit a SQL literal. diff --git a/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs b/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs index 4dcdb06433f8..1e6eb024919d 100644 --- a/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs +++ b/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs @@ -47,9 +47,8 @@ where let arr = ca.downcast_iter().next().unwrap(); // "5i" is a window size of 5, e.g. fixed - let arr = if options.window_size.parsed_int { + let arr = if options.by.is_none() { let options: RollingOptionsFixedWindow = options.try_into()?; - Ok(match ca.null_count() { 0 => rolling_agg_fn( arr.values().as_slice(), @@ -69,24 +68,20 @@ where ), }) } else { + let options: RollingOptionsDynamicWindow = options.try_into()?; if arr.null_count() > 0 { polars_bail!(InvalidOperation: "'Expr.rolling_*(..., by=...)' not yet supported for series with null values, consider using 'DataFrame.rolling' or 'Expr.rolling'") } let values = arr.values().as_slice(); - let duration = options.window_size; - polars_ensure!(duration.duration_ns() > 0 && !duration.negative, ComputeError:"window size should be strictly positive"); - let tu = options.tu.unwrap(); - let by = options.by.unwrap(); - let closed_window = options.closed_window.expect("closed window must be set"); - let func = rolling_agg_fn_dynamic.expect( - "'Expr.rolling_*(..., by=...)' not yet supported for this expression, consider using 'DataFrame.rolling' or 'Expr.rolling'", - ); + let tu = options.tu.expect("time_unit was set in `convert` function"); + let by = options.by; + let func = rolling_agg_fn_dynamic.expect("rolling_agg_fn_dynamic must have been passed"); func( values, - duration, + options.window_size, by, - closed_window, + options.closed_window, options.min_periods, tu, options.tz, diff --git a/crates/polars-time/src/chunkedarray/rolling_window/mod.rs b/crates/polars-time/src/chunkedarray/rolling_window/mod.rs index 07feca0a5a4c..d5ae53e1459f 100644 --- a/crates/polars-time/src/chunkedarray/rolling_window/mod.rs +++ b/crates/polars-time/src/chunkedarray/rolling_window/mod.rs @@ -80,32 +80,19 @@ pub struct RollingOptionsImpl<'a> { pub fn_params: DynArgs, } -impl TryFrom for RollingOptionsImpl<'static> { - type Error = PolarsError; - - fn try_from(options: RollingOptions) -> PolarsResult { - let window_size = options.window_size; - assert!( - window_size.parsed_int, - "should be fixed integer window size at this point" - ); - polars_ensure!( - options.closed_window.is_none(), - InvalidOperation: "`closed_window` is not supported for fixed window size rolling aggregations, \ - consider using DataFrame.rolling for greater flexibility", - ); - - Ok(RollingOptionsImpl { - window_size, +impl From for RollingOptionsImpl<'static> { + fn from(options: RollingOptions) -> 
Self { + RollingOptionsImpl { + window_size: options.window_size, min_periods: options.min_periods, weights: options.weights, center: options.center, by: None, tu: None, tz: None, - closed_window: None, + closed_window: options.closed_window, fn_params: options.fn_params, - }) + } } } @@ -128,19 +115,17 @@ impl Default for RollingOptionsImpl<'static> { impl<'a> TryFrom> for RollingOptionsFixedWindow { type Error = PolarsError; fn try_from(options: RollingOptionsImpl<'a>) -> PolarsResult { - let window_size = options.window_size; - assert!( - window_size.parsed_int, - "should be fixed integer window size at this point" + polars_ensure!( + options.window_size.parsed_int, + InvalidOperation: "if `window_size` is a temporal window (e.g. '1d', '2h, ...), then the `by` argument must be passed" ); polars_ensure!( options.closed_window.is_none(), InvalidOperation: "`closed_window` is not supported for fixed window size rolling aggregations, \ consider using DataFrame.rolling for greater flexibility", ); - let window_size = window_size.nanoseconds() as usize; + let window_size = options.window_size.nanoseconds() as usize; check_input(window_size, options.min_periods)?; - Ok(RollingOptionsFixedWindow { window_size, min_periods: options.min_periods, @@ -159,3 +144,41 @@ fn check_input(window_size: usize, min_periods: usize) -> PolarsResult<()> { ); Ok(()) } + +#[derive(Clone)] +pub struct RollingOptionsDynamicWindow<'a> { + /// The length of the window. + pub window_size: Duration, + /// Amount of elements in the window that should be filled before computing a result. + pub min_periods: usize, + pub by: &'a [i64], + pub tu: Option, + pub tz: Option<&'a TimeZone>, + pub closed_window: ClosedWindow, + pub fn_params: DynArgs, +} + +impl<'a> TryFrom> for RollingOptionsDynamicWindow<'a> { + type Error = PolarsError; + fn try_from(options: RollingOptionsImpl<'a>) -> PolarsResult { + let duration = options.window_size; + polars_ensure!(duration.duration_ns() > 0 && !duration.negative, ComputeError:"window size should be strictly positive"); + polars_ensure!( + options.weights.is_none(), + InvalidOperation: "`weights` is not supported in 'rolling_*(..., by=...)' expression" + ); + polars_ensure!( + !options.window_size.parsed_int, + InvalidOperation: "if `by` argument is passed, then `window_size` must be a temporal window (e.g. 
'1d' or '2h', not '3i')" + ); + Ok(RollingOptionsDynamicWindow { + window_size: options.window_size, + min_periods: options.min_periods, + by: options.by.expect("by must have been set to get here"), + tu: options.tu, + tz: options.tz, + closed_window: options.closed_window.unwrap_or(ClosedWindow::Right), + fn_params: options.fn_params, + }) + } +} diff --git a/crates/polars-time/src/group_by/dynamic.rs b/crates/polars-time/src/group_by/dynamic.rs index 4a22d21f8a0c..e6df8579fc9b 100644 --- a/crates/polars-time/src/group_by/dynamic.rs +++ b/crates/polars-time/src/group_by/dynamic.rs @@ -152,6 +152,14 @@ impl Wrap<&DataFrame> { TimeUnit::Milliseconds, None, ), + Duration(tu) => { + let time_type_dt = Datetime(*tu, None); + let dt = time.cast(&time_type_dt).unwrap(); + let (out, by, gt) = + self.impl_rolling(dt, group_by, options, *tu, None, &time_type_dt)?; + let out = out.cast(&Duration(*tu)).unwrap(); + return Ok((out, by, gt)); + }, UInt32 | UInt64 | Int32 => { let time_type_dt = Datetime(TimeUnit::Nanoseconds, None); let dt = time.cast(&Int64).unwrap().cast(&time_type_dt).unwrap(); @@ -182,7 +190,7 @@ impl Wrap<&DataFrame> { }, dt => polars_bail!( ComputeError: - "expected any of the following dtypes: {{ Date, Datetime, Int32, Int64, UInt32, UInt64 }}, got {}", + "expected any of the following dtypes: {{ Date, Datetime, Duration, Int32, Int64, UInt32, UInt64 }}, got {}", dt ), }; diff --git a/crates/polars-utils/src/sync.rs b/crates/polars-utils/src/sync.rs index 3659130990b2..e4257ac17b82 100644 --- a/crates/polars-utils/src/sync.rs +++ b/crates/polars-utils/src/sync.rs @@ -13,11 +13,7 @@ impl SyncPtr { Self(ptr) } - /// # Safety - /// - /// This will make a pointer sync and send. - /// Ensure that you don't break aliasing rules. - pub unsafe fn from_const(ptr: *const T) -> Self { + pub fn from_const(ptr: *const T) -> Self { Self(ptr as *mut T) } @@ -43,3 +39,9 @@ impl SyncPtr { unsafe impl Sync for SyncPtr {} unsafe impl Send for SyncPtr {} + +impl From<*const T> for SyncPtr { + fn from(value: *const T) -> Self { + Self::from_const(value) + } +} diff --git a/crates/polars/tests/it/lazy/queries.rs b/crates/polars/tests/it/lazy/queries.rs index 8433d57cff74..e266a89485df 100644 --- a/crates/polars/tests/it/lazy/queries.rs +++ b/crates/polars/tests/it/lazy/queries.rs @@ -66,6 +66,68 @@ fn test_special_group_by_schemas() -> PolarsResult<()> { &[3, 5, 7, 9, 5] ); + // Duration index column - period have different units + let out = df + .clone() + .lazy() + .with_column( + col("a") + .cast(DataType::Duration(TimeUnit::Milliseconds)) + .set_sorted_flag(IsSorted::Ascending), + ) + .rolling( + col("a"), + [], + RollingGroupOptions { + period: Duration::parse("2ms"), + offset: Duration::parse("0ms"), + closed_window: ClosedWindow::Left, + ..Default::default() + }, + ) + .agg([col("b").sum().alias("sum")]) + .select([col("a"), col("sum")]) + .collect()?; + + assert_eq!( + out.column("sum")? + .i32()? + .into_no_null_iter() + .collect::>(), + &[3, 5, 7, 9, 5] + ); + + // Datetime index column - period have same units + let out = df + .clone() + .lazy() + .with_column( + col("a") + .cast(DataType::Datetime(TimeUnit::Milliseconds, None)) + .set_sorted_flag(IsSorted::Ascending), + ) + .rolling( + col("a"), + [], + RollingGroupOptions { + period: Duration::parse("2ms"), + offset: Duration::parse("0ms"), + closed_window: ClosedWindow::Left, + ..Default::default() + }, + ) + .agg([col("b").sum().alias("sum")]) + .select([col("a"), col("sum")]) + .collect()?; + + assert_eq!( + out.column("sum")? 
+ .i32()? + .into_no_null_iter() + .collect::>(), + &[3, 5, 7, 9, 5] + ); + let out = df .lazy() .with_column(col("a").set_sorted_flag(IsSorted::Ascending)) diff --git a/docs/src/python/user-guide/expressions/user-defined-functions.py b/docs/src/python/user-guide/expressions/user-defined-functions.py index e0658b2d36a4..6edb63f5024a 100644 --- a/docs/src/python/user-guide/expressions/user-defined-functions.py +++ b/docs/src/python/user-guide/expressions/user-defined-functions.py @@ -16,7 +16,9 @@ # --8<-- [start:shift_map_batches] out = df.group_by("keys", maintain_order=True).agg( - pl.col("values").map_batches(lambda s: s.shift()).alias("shift_map_batches"), + pl.col("values") + .map_batches(lambda s: s.shift(), is_elementwise=True) + .alias("shift_map_batches"), pl.col("values").shift().alias("shift_expression"), ) print(out) @@ -25,7 +27,9 @@ # --8<-- [start:map_elements] out = df.group_by("keys", maintain_order=True).agg( - pl.col("values").map_elements(lambda s: s.shift()).alias("shift_map_elements"), + pl.col("values") + .map_elements(lambda s: s.shift(), return_dtype=pl.List(int)) + .alias("shift_map_elements"), pl.col("values").shift().alias("shift_expression"), ) print(out) diff --git a/docs/user-guide/expressions/user-defined-functions.md b/docs/user-guide/expressions/user-defined-functions.md index 882cc11c6ac1..7ced2fb0d50a 100644 --- a/docs/user-guide/expressions/user-defined-functions.md +++ b/docs/user-guide/expressions/user-defined-functions.md @@ -74,7 +74,7 @@ Let's try that out and see what we get: Ouch.. we clearly get the wrong results here. Group `"b"` even got a value from group `"a"` ðŸ˜ĩ. -This went horribly wrong, because the `map_batches` applies the function before we aggregate! So that means the whole column `[10, 7, 1`\] got shifted to `[null, 10, 7]` and was then aggregated. +This went horribly wrong because `map_batches` applied the function before aggregation, due to the `is_elementwise=True` parameter being provided. So that means the whole column `[10, 7, 1]` got shifted to `[null, 10, 7]` and was then aggregated. So my advice is to never use `map_batches` in the `group_by` context unless you know you need it and know what you are doing. diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 543b429a6c25..7a47082b4d29 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -6199,10 +6199,10 @@ def map_rows( Notes ----- - * The frame-level `apply` cannot track column names (as the UDF is a black-box - that may arbitrarily drop, rearrange, transform, or add new columns); if you - want to apply a UDF such that column names are preserved, you should use the - expression-level `apply` syntax instead. + * The frame-level `map_rows` cannot track column names (as the UDF is a + black-box that may arbitrarily drop, rearrange, transform, or add new + columns); if you want to apply a UDF such that column names are preserved, + you should use the expression-level `map_elements` syntax instead. * If your function is expensive and you don't want it to be called more than once for a given input, consider applying an `@lru_cache` decorator to it. diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 2e95a782c396..7053844c3f98 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -2046,7 +2046,7 @@ def top_k( This has time complexity: - .. math:: O(n + k \\log{}n - \frac{k}{2}) + .. 
math:: O(n + k \log{}n - \frac{k}{2}) Parameters ---------- @@ -2212,7 +2212,7 @@ def bottom_k( This has time complexity: - .. math:: O(n + k \\log{}n - \frac{k}{2}) + .. math:: O(n + k \log{}n - \frac{k}{2}) Parameters ---------- @@ -6238,10 +6238,11 @@ def rolling_min( │ 23 ┆ 2001-01-01 23:00:00 │ │ 24 ┆ 2001-01-02 00:00:00 │ └───────â”ī─────────────────────┘ + + Compute the rolling min with the temporal windows closed on the right (default) + >>> df_temporal.with_columns( - ... rolling_row_min=pl.col("index").rolling_min( - ... window_size="2h", by="date", closed="left" - ... ) + ... rolling_row_min=pl.col("index").rolling_min(window_size="2h", by="date") ... ) shape: (25, 3) ┌───────┮─────────────────────┮─────────────────┐ @@ -6249,17 +6250,17 @@ def rolling_min( │ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[Ξs] ┆ u32 │ ╞═══════╩═════════════════════╩═════════════════╡ - │ 0 ┆ 2001-01-01 00:00:00 ┆ null │ + │ 0 ┆ 2001-01-01 00:00:00 ┆ 0 │ │ 1 ┆ 2001-01-01 01:00:00 ┆ 0 │ - │ 2 ┆ 2001-01-01 02:00:00 ┆ 0 │ - │ 3 ┆ 2001-01-01 03:00:00 ┆ 1 │ - │ 4 ┆ 2001-01-01 04:00:00 ┆ 2 │ + │ 2 ┆ 2001-01-01 02:00:00 ┆ 1 │ + │ 3 ┆ 2001-01-01 03:00:00 ┆ 2 │ + │ 4 ┆ 2001-01-01 04:00:00 ┆ 3 │ │ â€Ķ ┆ â€Ķ ┆ â€Ķ │ - │ 20 ┆ 2001-01-01 20:00:00 ┆ 18 │ - │ 21 ┆ 2001-01-01 21:00:00 ┆ 19 │ - │ 22 ┆ 2001-01-01 22:00:00 ┆ 20 │ - │ 23 ┆ 2001-01-01 23:00:00 ┆ 21 │ - │ 24 ┆ 2001-01-02 00:00:00 ┆ 22 │ + │ 20 ┆ 2001-01-01 20:00:00 ┆ 19 │ + │ 21 ┆ 2001-01-01 21:00:00 ┆ 20 │ + │ 22 ┆ 2001-01-01 22:00:00 ┆ 21 │ + │ 23 ┆ 2001-01-01 23:00:00 ┆ 22 │ + │ 24 ┆ 2001-01-02 00:00:00 ┆ 23 │ └───────â”ī─────────────────────â”ī─────────────────┘ """ window_size = deprecate_saturating(window_size) @@ -6447,12 +6448,10 @@ def rolling_max( │ 24 ┆ 2001-01-02 00:00:00 │ └───────â”ī─────────────────────┘ - Compute the rolling max with the default left closure of temporal windows + Compute the rolling max with the temporal windows closed on the right (default) >>> df_temporal.with_columns( - ... rolling_row_max=pl.col("index").rolling_max( - ... window_size="2h", by="date", closed="left" - ... ) + ... rolling_row_max=pl.col("index").rolling_max(window_size="2h", by="date") ... ) shape: (25, 3) ┌───────┮─────────────────────┮─────────────────┐ @@ -6460,17 +6459,17 @@ def rolling_max( │ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[Ξs] ┆ u32 │ ╞═══════╩═════════════════════╩═════════════════╡ - │ 0 ┆ 2001-01-01 00:00:00 ┆ null │ - │ 1 ┆ 2001-01-01 01:00:00 ┆ 0 │ - │ 2 ┆ 2001-01-01 02:00:00 ┆ 1 │ - │ 3 ┆ 2001-01-01 03:00:00 ┆ 2 │ - │ 4 ┆ 2001-01-01 04:00:00 ┆ 3 │ + │ 0 ┆ 2001-01-01 00:00:00 ┆ 0 │ + │ 1 ┆ 2001-01-01 01:00:00 ┆ 1 │ + │ 2 ┆ 2001-01-01 02:00:00 ┆ 2 │ + │ 3 ┆ 2001-01-01 03:00:00 ┆ 3 │ + │ 4 ┆ 2001-01-01 04:00:00 ┆ 4 │ │ â€Ķ ┆ â€Ķ ┆ â€Ķ │ - │ 20 ┆ 2001-01-01 20:00:00 ┆ 19 │ - │ 21 ┆ 2001-01-01 21:00:00 ┆ 20 │ - │ 22 ┆ 2001-01-01 22:00:00 ┆ 21 │ - │ 23 ┆ 2001-01-01 23:00:00 ┆ 22 │ - │ 24 ┆ 2001-01-02 00:00:00 ┆ 23 │ + │ 20 ┆ 2001-01-01 20:00:00 ┆ 20 │ + │ 21 ┆ 2001-01-01 21:00:00 ┆ 21 │ + │ 22 ┆ 2001-01-01 22:00:00 ┆ 22 │ + │ 23 ┆ 2001-01-01 23:00:00 ┆ 23 │ + │ 24 ┆ 2001-01-02 00:00:00 ┆ 24 │ └───────â”ī─────────────────────â”ī─────────────────┘ Compute the rolling max with the closure of windows on both sides @@ -6688,11 +6687,11 @@ def rolling_mean( │ 24 ┆ 2001-01-02 00:00:00 │ └───────â”ī─────────────────────┘ - Compute the rolling mean with the default left closure of temporal windows + Compute the rolling mean with the temporal windows closed on the right (default) >>> df_temporal.with_columns( ... 
rolling_row_mean=pl.col("index").rolling_mean( - ... window_size="2h", by="date", closed="left" + ... window_size="2h", by="date" ... ) ... ) shape: (25, 3) @@ -6701,17 +6700,17 @@ def rolling_mean( │ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[Ξs] ┆ f64 │ ╞═══════╩═════════════════════╩══════════════════╡ - │ 0 ┆ 2001-01-01 00:00:00 ┆ null │ - │ 1 ┆ 2001-01-01 01:00:00 ┆ 0.0 │ - │ 2 ┆ 2001-01-01 02:00:00 ┆ 0.5 │ - │ 3 ┆ 2001-01-01 03:00:00 ┆ 1.5 │ - │ 4 ┆ 2001-01-01 04:00:00 ┆ 2.5 │ + │ 0 ┆ 2001-01-01 00:00:00 ┆ 0.0 │ + │ 1 ┆ 2001-01-01 01:00:00 ┆ 0.5 │ + │ 2 ┆ 2001-01-01 02:00:00 ┆ 1.5 │ + │ 3 ┆ 2001-01-01 03:00:00 ┆ 2.5 │ + │ 4 ┆ 2001-01-01 04:00:00 ┆ 3.5 │ │ â€Ķ ┆ â€Ķ ┆ â€Ķ │ - │ 20 ┆ 2001-01-01 20:00:00 ┆ 18.5 │ - │ 21 ┆ 2001-01-01 21:00:00 ┆ 19.5 │ - │ 22 ┆ 2001-01-01 22:00:00 ┆ 20.5 │ - │ 23 ┆ 2001-01-01 23:00:00 ┆ 21.5 │ - │ 24 ┆ 2001-01-02 00:00:00 ┆ 22.5 │ + │ 20 ┆ 2001-01-01 20:00:00 ┆ 19.5 │ + │ 21 ┆ 2001-01-01 21:00:00 ┆ 20.5 │ + │ 22 ┆ 2001-01-01 22:00:00 ┆ 21.5 │ + │ 23 ┆ 2001-01-01 23:00:00 ┆ 22.5 │ + │ 24 ┆ 2001-01-02 00:00:00 ┆ 23.5 │ └───────â”ī─────────────────────â”ī──────────────────┘ Compute the rolling mean with the closure of windows on both sides @@ -6931,12 +6930,10 @@ def rolling_sum( │ 24 ┆ 2001-01-02 00:00:00 │ └───────â”ī─────────────────────┘ - Compute the rolling sum with the default left closure of temporal windows + Compute the rolling sum with the temporal windows closed on the right (default) >>> df_temporal.with_columns( - ... rolling_row_sum=pl.col("index").rolling_sum( - ... window_size="2h", by="date", closed="left" - ... ) + ... rolling_row_sum=pl.col("index").rolling_sum(window_size="2h", by="date") ... ) shape: (25, 3) ┌───────┮─────────────────────┮─────────────────┐ @@ -6944,17 +6941,17 @@ def rolling_sum( │ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[Ξs] ┆ u32 │ ╞═══════╩═════════════════════╩═════════════════╡ - │ 0 ┆ 2001-01-01 00:00:00 ┆ null │ - │ 1 ┆ 2001-01-01 01:00:00 ┆ 0 │ - │ 2 ┆ 2001-01-01 02:00:00 ┆ 1 │ - │ 3 ┆ 2001-01-01 03:00:00 ┆ 3 │ - │ 4 ┆ 2001-01-01 04:00:00 ┆ 5 │ + │ 0 ┆ 2001-01-01 00:00:00 ┆ 0 │ + │ 1 ┆ 2001-01-01 01:00:00 ┆ 1 │ + │ 2 ┆ 2001-01-01 02:00:00 ┆ 3 │ + │ 3 ┆ 2001-01-01 03:00:00 ┆ 5 │ + │ 4 ┆ 2001-01-01 04:00:00 ┆ 7 │ │ â€Ķ ┆ â€Ķ ┆ â€Ķ │ - │ 20 ┆ 2001-01-01 20:00:00 ┆ 37 │ - │ 21 ┆ 2001-01-01 21:00:00 ┆ 39 │ - │ 22 ┆ 2001-01-01 22:00:00 ┆ 41 │ - │ 23 ┆ 2001-01-01 23:00:00 ┆ 43 │ - │ 24 ┆ 2001-01-02 00:00:00 ┆ 45 │ + │ 20 ┆ 2001-01-01 20:00:00 ┆ 39 │ + │ 21 ┆ 2001-01-01 21:00:00 ┆ 41 │ + │ 22 ┆ 2001-01-01 22:00:00 ┆ 43 │ + │ 23 ┆ 2001-01-01 23:00:00 ┆ 45 │ + │ 24 ┆ 2001-01-02 00:00:00 ┆ 47 │ └───────â”ī─────────────────────â”ī─────────────────┘ Compute the rolling sum with the closure of windows on both sides @@ -7172,12 +7169,10 @@ def rolling_std( │ 24 ┆ 2001-01-02 00:00:00 │ └───────â”ī─────────────────────┘ - Compute the rolling std with the default left closure of temporal windows + Compute the rolling std with the temporal windows closed on the right (default) >>> df_temporal.with_columns( - ... rolling_row_std=pl.col("index").rolling_std( - ... window_size="2h", by="date", closed="left" - ... ) + ... rolling_row_std=pl.col("index").rolling_std(window_size="2h", by="date") ... 
@@ -7172,12 +7169,10 @@ def rolling_std(
         │ 24    ┆ 2001-01-02 00:00:00 │
         └───────┴─────────────────────┘
 
-        Compute the rolling std with the default left closure of temporal windows
+        Compute the rolling std with the temporal windows closed on the right (default)
 
         >>> df_temporal.with_columns(
-        ...     rolling_row_std=pl.col("index").rolling_std(
-        ...         window_size="2h", by="date", closed="left"
-        ...     )
+        ...     rolling_row_std=pl.col("index").rolling_std(window_size="2h", by="date")
         ... )
         shape: (25, 3)
         ┌───────┬─────────────────────┬─────────────────┐
@@ -7186,7 +7181,7 @@ def rolling_std(
         │ u32   ┆ datetime[μs]        ┆ f64             │
         ╞═══════╪═════════════════════╪═════════════════╡
         │ 0     ┆ 2001-01-01 00:00:00 ┆ null            │
-        │ 1     ┆ 2001-01-01 01:00:00 ┆ null            │
+        │ 1     ┆ 2001-01-01 01:00:00 ┆ 0.707107        │
         │ 2     ┆ 2001-01-01 02:00:00 ┆ 0.707107        │
         │ 3     ┆ 2001-01-01 03:00:00 ┆ 0.707107        │
         │ 4     ┆ 2001-01-01 04:00:00 ┆ 0.707107        │
@@ -7419,12 +7414,10 @@ def rolling_var(
         │ 24    ┆ 2001-01-02 00:00:00 │
         └───────┴─────────────────────┘
 
-        Compute the rolling var with the default left closure of temporal windows
+        Compute the rolling var with the temporal windows closed on the right (default)
 
         >>> df_temporal.with_columns(
-        ...     rolling_row_var=pl.col("index").rolling_var(
-        ...         window_size="2h", by="date", closed="left"
-        ...     )
+        ...     rolling_row_var=pl.col("index").rolling_var(window_size="2h", by="date")
         ... )
         shape: (25, 3)
         ┌───────┬─────────────────────┬─────────────────┐
@@ -7433,7 +7426,7 @@ def rolling_var(
         │ u32   ┆ datetime[μs]        ┆ f64             │
         ╞═══════╪═════════════════════╪═════════════════╡
         │ 0     ┆ 2001-01-01 00:00:00 ┆ null            │
-        │ 1     ┆ 2001-01-01 01:00:00 ┆ null            │
+        │ 1     ┆ 2001-01-01 01:00:00 ┆ 0.5             │
         │ 2     ┆ 2001-01-01 02:00:00 ┆ 0.5             │
         │ 3     ┆ 2001-01-01 03:00:00 ┆ 0.5             │
         │ 4     ┆ 2001-01-01 04:00:00 ┆ 0.5             │
diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py
index f6099f22bb91..bacd0141c266 100644
--- a/py-polars/polars/series/series.py
+++ b/py-polars/polars/series/series.py
@@ -3402,7 +3402,7 @@ def top_k(self, k: int | IntoExprColumn = 5) -> Series:
 
         This has time complexity:
 
-        .. math:: O(n + k \\log{}n - \frac{k}{2})
+        .. math:: O(n + k \log{}n - \frac{k}{2})
 
         Parameters
         ----------
@@ -3432,7 +3432,7 @@ def bottom_k(self, k: int | IntoExprColumn = 5) -> Series:
 
         This has time complexity:
 
-        .. math:: O(n + k \\log{}n - \frac{k}{2})
+        .. math:: O(n + k \log{}n - \frac{k}{2})
 
         Parameters
         ----------
diff --git a/py-polars/polars/testing/parametric/primitives.py b/py-polars/polars/testing/parametric/primitives.py
index bf5c30e19c88..fc41af3e8b7a 100644
--- a/py-polars/polars/testing/parametric/primitives.py
+++ b/py-polars/polars/testing/parametric/primitives.py
@@ -28,10 +28,10 @@ from polars.string_cache import StringCache
 from polars.testing.parametric.strategies import (
     _flexhash,
-    all_strategies,
     between,
     create_array_strategy,
     create_list_strategy,
+    dtype_strategies,
     scalar_strategies,
 )
 
@@ -381,11 +381,7 @@ def draw_series(draw: DrawFn) -> Series:
         if strategy is None:
             if series_dtype is Datetime or series_dtype is Duration:
                 series_dtype = series_dtype(random.choice(_time_units))  # type: ignore[operator]
-            dtype_strategy = all_strategies[
-                series_dtype
-                if series_dtype in all_strategies
-                else series_dtype.base_type()
-            ]
+            dtype_strategy = draw(dtype_strategies(series_dtype))
         else:
             dtype_strategy = strategy
 
diff --git a/py-polars/polars/testing/parametric/strategies.py b/py-polars/polars/testing/parametric/strategies.py
index 7e03e3808e36..2cfc3626c478 100644
--- a/py-polars/polars/testing/parametric/strategies.py
+++ b/py-polars/polars/testing/parametric/strategies.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from datetime import datetime, timedelta
+from decimal import Decimal as PyDecimal
 from itertools import chain
 from random import choice, randint, shuffle
 from string import ascii_uppercase
@@ -14,6 +15,7 @@
     Sequence,
 )
 
+import hypothesis.strategies as st
 from hypothesis.strategies import (
     SearchStrategy,
     binary,
@@ -22,7 +24,6 @@
     composite,
     dates,
     datetimes,
-    decimals,
     floats,
     from_type,
     integers,
@@ -56,13 +57,11 @@
     UInt16,
     UInt32,
     UInt64,
-    is_polars_dtype,
 )
 from polars.type_aliases import PolarsDataType
 
 if TYPE_CHECKING:
     import sys
-    from decimal import Decimal as PyDecimal
 
     from hypothesis.strategies import DrawFn
 
@@ -72,6 +71,26 @@
         from typing_extensions import Self
 
 
+@composite
+def dtype_strategies(draw: DrawFn, dtype: PolarsDataType) -> SearchStrategy[Any]:
+    """Returns a strategy which generates valid values for the given data type."""
+    if (strategy := all_strategies.get(dtype)) is not None:
+        return strategy
+    elif (strategy_base := all_strategies.get(dtype.base_type())) is not None:
+        return strategy_base
+
+    if dtype == Decimal:
+        return draw(
+            decimal_strategies(
+                precision=getattr(dtype, "precision", None),
+                scale=getattr(dtype, "scale", None),
+            )
+        )
+    else:
+        msg = f"unsupported data type: {dtype}"
+        raise TypeError(msg)
+
+
 def between(draw: DrawFn, type_: type, min_: Any, max_: Any) -> Any:
     """Draw a value in a given range from a type-inferred strategy."""
     strategy_init = from_type(type_).function  # type: ignore[attr-defined]
@@ -117,19 +136,28 @@
 
 
 @composite
-def strategy_decimal(draw: DrawFn) -> PyDecimal:
-    """Draw a decimal value, varying the number of decimal places."""
-    places = draw(integers(min_value=0, max_value=18))
-    return draw(
-        # TODO: once fixed, re-enable decimal nan/inf values...
-        #  (see https://github.com/pola-rs/polars/issues/8421)
-        decimals(
-            allow_nan=False,
-            allow_infinity=False,
-            min_value=-(2**66),
-            max_value=(2**66) - 1,
-            places=places,
-        )
+def decimal_strategies(
+    draw: DrawFn, precision: int | None = None, scale: int | None = None
+) -> SearchStrategy[PyDecimal]:
+    """Returns a strategy which generates instances of Python `Decimal`."""
+    if precision is None:
+        precision = draw(integers(min_value=scale or 1, max_value=38))
+    if scale is None:
+        scale = draw(integers(min_value=0, max_value=precision))
+
+    exclusive_limit = PyDecimal(f"1E+{precision - scale}")
+    epsilon = PyDecimal(f"1E-{scale}")
+    limit = exclusive_limit - epsilon
+    if limit == exclusive_limit:  # Limit cannot be set exactly due to precision issues
+        multiplier = PyDecimal("1") - PyDecimal("1E-20")  # 0.999...
+        limit = limit * multiplier
+
+    return st.decimals(
+        allow_nan=False,
+        allow_infinity=False,
+        min_value=-limit,
+        max_value=limit,
+        places=scale,
     )
@@ -272,34 +300,15 @@ def update(self, items: StrategyLookup) -> Self:  # type: ignore[override]
         Categorical: strategy_categorical,
         String: strategy_string,
         Binary: strategy_binary,
-        Decimal: strategy_decimal(),
     }
 )
 nested_strategies: StrategyLookup = StrategyLookup()
 
 
-def _get_strategy_dtypes(
-    *,
-    base_type: bool = False,
-    excluding: tuple[PolarsDataType] | PolarsDataType | None = None,
-) -> list[PolarsDataType]:
-    """
-    Get a list of all the dtypes for which we have a strategy.
-
-    Parameters
-    ----------
-    base_type
-        If True, return the base types for each dtype (eg:`List(String)` → `List`).
-    excluding
-        A dtype or sequence of dtypes to omit from the results.
-    """
-    excluding = (excluding,) if is_polars_dtype(excluding) else (excluding or ())  # type: ignore[assignment]
+def _get_strategy_dtypes() -> list[PolarsDataType]:
+    """Get a list of all the dtypes for which we have a strategy."""
     strategy_dtypes = list(chain(scalar_strategies.keys(), nested_strategies.keys()))
-    return [
-        (tp.base_type() if base_type else tp)
-        for tp in strategy_dtypes
-        if tp not in excluding  # type: ignore[operator]
-    ]
+    return [tp.base_type() for tp in strategy_dtypes]
 
 
 def _flexhash(elem: Any) -> int:
@@ -351,7 +360,7 @@ def create_array_strategy(
         width = randint(a=1, b=8)
 
     if inner_dtype is None:
-        strats = list(_get_strategy_dtypes(base_type=True))
+        strats = list(_get_strategy_dtypes())
         shuffle(strats)
         inner_dtype = choice(strats)
 
@@ -431,7 +440,7 @@ def create_list_strategy(
         raise ValueError(msg)
 
     if inner_dtype is None:
-        strats = list(_get_strategy_dtypes(base_type=True))
+        strats = list(_get_strategy_dtypes())
         shuffle(strats)
         inner_dtype = choice(strats)
     if size:
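The reworked strategy above draws `precision` and `scale` when they are not supplied and bounds generated values at 10**(precision - scale) minus one epsilon. A small sketch of how it composes with hypothesis (test name and body hypothetical; `decimal_strategies` as defined in the hunk above):

    from decimal import Decimal

    import hypothesis.strategies as st
    from hypothesis import given

    from polars.testing.parametric.strategies import decimal_strategies


    @given(data=st.data())
    def test_decimal_bounds(data: st.DataObject) -> None:
        # Drawing the composite yields a bounded st.decimals strategy; with
        # precision=5 and scale=2 the limit is 10**3 - 10**-2 = 999.99.
        strategy = data.draw(decimal_strategies(precision=5, scale=2))
        value = data.draw(strategy)
        assert abs(value) <= Decimal("999.99")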
diff --git a/py-polars/src/functions/lazy.rs b/py-polars/src/functions/lazy.rs
index b3c22dd6cca2..5266b9180aa3 100644
--- a/py-polars/src/functions/lazy.rs
+++ b/py-polars/src/functions/lazy.rs
@@ -1,5 +1,6 @@
 use polars::lazy::dsl;
 use polars::prelude::*;
+use polars_plan::prelude::UnionArgs;
 use pyo3::exceptions::PyTypeError;
 use pyo3::prelude::*;
 use pyo3::types::{PyBool, PyBytes, PyFloat, PyInt, PyString};
@@ -172,6 +173,7 @@ pub fn concat_lf(
             rechunk,
             parallel,
             to_supertypes,
+            ..Default::default()
         },
     )
     .map_err(PyPolarsErr::from)?;
@@ -288,6 +290,7 @@ pub fn concat_lf_diagonal(
             rechunk,
             parallel,
             to_supertypes,
+            ..Default::default()
         },
     )
     .map_err(PyPolarsErr::from)?;
@@ -309,6 +312,7 @@ pub fn concat_lf_horizontal(lfs: &PyAny, parallel: bool) -> PyResult
None:
     df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).sort("a")
-    with pytest.raises(InvalidOperationError, match="`rolling_min` operation"):
+    msg = "in `rolling_min` operation, `by` argument of dtype `i64` is not supported"
+    with pytest.raises(InvalidOperationError, match=msg):
         df.select(pl.col("b").rolling_min(2, by="a"))
+    df = pl.DataFrame({"a": [1, 2, 3], "b": [date(2020, 1, 1)] * 3}).sort("b")
+    msg = "if `by` argument is passed, then `window_size` must be a temporal window"
+    with pytest.raises(InvalidOperationError, match=msg):
+        df.select(pl.col("a").rolling_min(2, by="b"))
 
 
 def test_rolling_infinity() -> None:
@@ -236,10 +241,26 @@ def test_rolling_invalid_closed_option() -> None:
     ).sort("a", "b")
     with pytest.raises(InvalidOperationError, match="consider using DataFrame.rolling"):
         df.with_columns(pl.col("a").rolling_sum(2, closed="left"))
-    with pytest.raises(InvalidOperationError, match="consider using DataFrame.rolling"):
+
+
+def test_rolling_by_non_temporal_window_size() -> None:
+    df = pl.DataFrame(
+        {"a": [4, 5, 6], "b": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)]}
+    ).sort("a", "b")
+    msg = "if `by` argument is passed, then `window_size` must be a temporal window"
+    with pytest.raises(InvalidOperationError, match=msg):
         df.with_columns(pl.col("a").rolling_sum(2, by="b", closed="left"))
 
 
+def test_rolling_by_weights() -> None:
+    df = pl.DataFrame(
+        {"a": [4, 5, 6], "b": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)]}
+    ).sort("b")
+    msg = r"`weights` is not supported in 'rolling_\*\(..., by=...\)' expression"
+    with pytest.raises(InvalidOperationError, match=msg):
+        df.with_columns(pl.col("a").rolling_sum("2d", by="b", weights=[1, 2]))
+
+
 def test_rolling_extrema() -> None:
     # sorted data and nulls flags trigger different kernels
     df = (
@@ -955,3 +976,53 @@ def test_rolling_invalid() -> None:
         .rolling("index", period="3000d")
         .agg(pl.col("values").sum().alias("sum"))
     )
+
+
+@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"])
+def test_rolling_duration(time_unit: Literal["ns", "us", "ms"]) -> None:
+    # Here we only test for consistency with datetime.
+    df = pl.DataFrame(
+        {
+            "index_column": [1, 2, 3, 4, 5],
+            "value": [
+                1,
+                10,
+                100,
+                1000,
+                10000,
+            ],
+        }
+    )
+    df_duration = df.select(
+        pl.col("index_column").cast(pl.Duration(time_unit=time_unit)).set_sorted(),
+        "value",
+    )
+
+    df_datetime = df.select(
+        pl.col("index_column").cast(pl.Datetime(time_unit=time_unit)).set_sorted(),
+        "value",
+    )
+
+    res_duration = df_duration.rolling(
+        index_column="index_column", period=f"2{time_unit}"
+    ).agg(pl.col("value").sum())
+
+    res_datetime = df_datetime.rolling(
+        index_column="index_column", period=f"2{time_unit}"
+    ).agg(pl.col("value").sum())
+
+    assert (
+        res_duration["value"].to_list() == res_datetime["value"].to_list()
+    ), f"{res_duration['value'].to_list()=}, {res_datetime['value'].to_list()=}"
+
+    assert res_duration["index_column"].dtype == pl.Duration(time_unit=time_unit)
+
+
+def test_temporal_windows_size_without_by_15977() -> None:
+    df = pl.DataFrame(
+        {"a": [1, 2, 3], "b": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)]}
+    )
+    with pytest.raises(
+        pl.InvalidOperationError, match="the `by` argument must be passed"
+    ):
+        df.select(pl.col("a").rolling_mean("3d"))
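Taken together, the tests above pin down the contract of the `rolling_*(..., by=...)` expressions: the `by` column must be a sorted temporal column, `window_size` must then be a duration string, and `weights` is rejected. A minimal sketch of the accepted form (data hypothetical, mirroring the test fixtures):

    from datetime import date

    import polars as pl

    df = pl.DataFrame(
        {"a": [4, 5, 6], "b": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)]}
    ).sort("b")

    # Accepted: duration-string window over a sorted temporal column, no weights.
    df.with_columns(rolling=pl.col("a").rolling_sum("2d", by="b"))

    # Each of these raises InvalidOperationError per the tests above:
    # df.with_columns(pl.col("a").rolling_sum(2, by="b"))  # integer window
    # df.with_columns(pl.col("a").rolling_sum("2d", by="b", weights=[1, 2]))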
diff --git a/py-polars/tests/unit/operations/test_cut.py b/py-polars/tests/unit/operations/test_cut.py
index b87381e94eff..25003028a23e 100644
--- a/py-polars/tests/unit/operations/test_cut.py
+++ b/py-polars/tests/unit/operations/test_cut.py
@@ -120,3 +120,47 @@ def test_cut_bin_name_in_agg_context() -> None:
     )
     schema = pl.Struct({"brk": pl.Float64, "a_bin": pl.Categorical("physical")})
     assert df.schema == {"cut": schema, "qcut": schema, "qcut_uniform": schema}
+
+
+@pytest.mark.parametrize(
+    ("breaks", "expected_labels", "expected_physical", "expected_unique"),
+    [
+        (
+            [2, 4],
+            pl.Series("x", ["(-inf, 2]", "(-inf, 2]", "(2, 4]", "(2, 4]", "(4, inf]"]),
+            pl.Series("x", [0, 0, 1, 1, 2], dtype=pl.UInt32),
+            3,
+        ),
+        (
+            [99, 101],
+            pl.Series("x", 5 * ["(-inf, 99]"]),
+            pl.Series("x", 5 * [0], dtype=pl.UInt32),
+            1,
+        ),
+    ],
+)
+def test_cut_fast_unique_15981(
+    breaks: list[int],
+    expected_labels: pl.Series,
+    expected_physical: pl.Series,
+    expected_unique: int,
+) -> None:
+    s = pl.Series("x", [1, 2, 3, 4, 5])
+
+    include_breaks = False
+    s_cut = s.cut(breaks, include_breaks=include_breaks)
+
+    assert_series_equal(s_cut.cast(pl.String), expected_labels)
+    assert_series_equal(s_cut.to_physical(), expected_physical)
+    assert s_cut.n_unique() == s_cut.to_physical().n_unique() == expected_unique
+    s_cut.to_frame().group_by(s.name).len()
+
+    include_breaks = True
+    s_cut = (
+        s.cut(breaks, include_breaks=include_breaks).struct.field("category").alias("x")
+    )
+
+    assert_series_equal(s_cut.cast(pl.String), expected_labels)
+    assert_series_equal(s_cut.to_physical(), expected_physical)
+    assert s_cut.n_unique() == s_cut.to_physical().n_unique() == expected_unique
+    s_cut.to_frame().group_by(s.name).len()
diff --git a/py-polars/tests/unit/sql/test_cast.py b/py-polars/tests/unit/sql/test_cast.py
index 22ffbfceb4aa..0f5cc61c1dc9 100644
--- a/py-polars/tests/unit/sql/test_cast.py
+++ b/py-polars/tests/unit/sql/test_cast.py
@@ -1,8 +1,11 @@
 from __future__ import annotations
 
+from typing import Any
+
 import pytest
 
 import polars as pl
+import polars.selectors as cs
 from polars.exceptions import ComputeError
 from polars.testing import assert_frame_equal
 
@@ -14,6 +17,7 @@ def test_cast() -> None:
             "b": [1.1, 2.2, 3.3, 4.4, 5.5],
             "c": ["a", "b", "c", "d", "e"],
             "d": [True, False, True, False, True],
+            "e": [-1, 0, None, 1, 2],
         }
     )
     # test various dtype casts, using standard ("CAST <col> AS <dtype>")
@@ -25,11 +29,29 @@ def test_cast() -> None:
             -- float
             CAST(a AS DOUBLE PRECISION) AS a_f64,
             a::real AS a_f32,
+            b::float(24) AS b_f32,
+            b::float(25) AS b_f64,
+            e::float8 AS e_f64,
+            e::float4 AS e_f32,
+
             -- integer
             CAST(b AS TINYINT) AS b_i8,
             CAST(b AS SMALLINT) AS b_i16,
             b::bigint AS b_i64,
             d::tinyint AS d_i8,
+            a::int1 AS a_i8,
+            a::int2 AS a_i16,
+            a::int4 AS a_i32,
+            a::int8 AS a_i64,
+
+            -- unsigned integer
+            CAST(a AS TINYINT UNSIGNED) AS a_u8,
+            d::uint1 AS d_u8,
+            a::uint2 AS a_u16,
+            b::uint4 AS b_u32,
+            b::uint8 AS b_u64,
+            CAST(a AS BIGINT UNSIGNED) AS a_u64,
+
             -- string/binary
             CAST(a AS CHAR) AS a_char,
             CAST(b AS VARCHAR) AS b_varchar,
@@ -37,29 +59,86 @@ def test_cast() -> None:
             c::bytes AS c_bytes,
             c::VARBINARY AS c_varbinary,
             CAST(d AS CHARACTER VARYING) AS d_charvar,
+
+            -- boolean
+            e::bool AS e_bool,
+            e::boolean AS e_boolean
         FROM df
         """
     )
     assert res.schema == {
         "a_f64": pl.Float64,
         "a_f32": pl.Float32,
+        "b_f32": pl.Float32,
+        "b_f64": pl.Float64,
+        "e_f64": pl.Float64,
+        "e_f32": pl.Float32,
         "b_i8": pl.Int8,
         "b_i16": pl.Int16,
         "b_i64": pl.Int64,
         "d_i8": pl.Int8,
+        "a_i8": pl.Int8,
+        "a_i16": pl.Int16,
+        "a_i32": pl.Int32,
+        "a_i64": pl.Int64,
+        "a_u8": pl.UInt8,
+        "d_u8": pl.UInt8,
+        "a_u16": pl.UInt16,
+        "b_u32": pl.UInt32,
+        "b_u64": pl.UInt64,
+        "a_u64": pl.UInt64,
         "a_char": pl.String,
         "b_varchar": pl.String,
         "c_blob": pl.Binary,
         "c_bytes": pl.Binary,
         "c_varbinary": pl.Binary,
         "d_charvar": pl.String,
+        "e_bool": pl.Boolean,
+        "e_boolean": pl.Boolean,
     }
-    assert res.rows() == [
-        (1.0, 1.0, 1, 1, 1, 1, "1", "1.1", b"a", b"a", b"a", "true"),
-        (2.0, 2.0, 2, 2, 2, 0, "2", "2.2", b"b", b"b", b"b", "false"),
-        (3.0, 3.0, 3, 3, 3, 1, "3", "3.3", b"c", b"c", b"c", "true"),
-        (4.0, 4.0, 4, 4, 4, 0, "4", "4.4", b"d", b"d", b"d", "false"),
-        (5.0, 5.0, 5, 5, 5, 1, "5", "5.5", b"e", b"e", b"e", "true"),
+    assert res.select(cs.by_dtype(pl.Float32)).rows() == pytest.approx(
+        [
+            (1.0, 1.100000023841858, -1.0),
+            (2.0, 2.200000047683716, 0.0),
+            (3.0, 3.299999952316284, None),
+            (4.0, 4.400000095367432, 1.0),
+            (5.0, 5.5, 2.0),
+        ]
+    )
+    assert res.select(cs.by_dtype(pl.Float64)).rows() == [
+        (1.0, 1.1, -1.0),
+        (2.0, 2.2, 0.0),
+        (3.0, 3.3, None),
+        (4.0, 4.4, 1.0),
+        (5.0, 5.5, 2.0),
+    ]
+    assert res.select(cs.integer()).rows() == [
+        (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
+        (2, 2, 2, 0, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2),
+        (3, 3, 3, 1, 3, 3, 3, 3, 3, 1, 3, 3, 3, 3),
+        (4, 4, 4, 0, 4, 4, 4, 4, 4, 0, 4, 4, 4, 4),
+        (5, 5, 5, 1, 5, 5, 5, 5, 5, 1, 5, 5, 5, 5),
+    ]
+    assert res.select(cs.string()).rows() == [
+        ("1", "1.1", "true"),
+        ("2", "2.2", "false"),
+        ("3", "3.3", "true"),
+        ("4", "4.4", "false"),
+        ("5", "5.5", "true"),
+    ]
+    assert res.select(cs.binary()).rows() == [
+        (b"a", b"a", b"a"),
+        (b"b", b"b", b"b"),
+        (b"c", b"c", b"c"),
+        (b"d", b"d", b"d"),
+        (b"e", b"e", b"e"),
+    ]
+    assert res.select(cs.boolean()).rows() == [
+        (True, True),
+        (False, False),
+        (None, None),
+        (True, True),
+        (True, True),
     ]
 
     with pytest.raises(ComputeError, match="unsupported use of FORMAT in CAST"):
@@ -68,6 +147,24 @@ def test_cast() -> None:
     )
 
 
+@pytest.mark.parametrize(
+    ("values", "cast_op", "error"),
+    [
+        ([1.0, -1.0], "values::uint8", "conversion from `f64` to `u64` failed"),
+        ([10, 0, -1], "values::uint4", "conversion from `i64` to `u32` failed"),
+        ([int(1e8)], "values::int1", "conversion from `i64` to `i8` failed"),
+        (["a", "b"], "values::date", "conversion from `str` to `date` failed"),
+        (["a", "b"], "values::time", "conversion from `str` to `time` failed"),
+        (["a", "b"], "values::int4", "conversion from `str` to `i32` failed"),
+    ],
+)
+def test_cast_errors(values: Any, cast_op: str, error: str) -> None:
+    df = pl.DataFrame({"values": values})
+
+    with pytest.raises(ComputeError, match=error):
+        df.sql(f"SELECT {cast_op} FROM df")
+
+
 def test_cast_json() -> None:
     df = pl.DataFrame({"txt": ['{"a":[1,2,3],"b":["x","y","z"],"c":5.0}']})
 
diff --git a/py-polars/tests/unit/test_cse.py b/py-polars/tests/unit/test_cse.py
index d400b8127f96..6ce6cfc4621e 100644
--- a/py-polars/tests/unit/test_cse.py
+++ b/py-polars/tests/unit/test_cse.py
@@ -723,3 +723,12 @@ def test_cse_drop_nulls_15795() -> None:
     C = A.join(B, on="X").select("X")
     D = B.select("X")
     assert C.join(D, on="X").collect().shape == (1, 1)
+
+
+def test_cse_no_projection_15980() -> None:
+    df = pl.LazyFrame({"x": "a", "y": 1})
+    df = pl.concat(df.with_columns(pl.col("y").add(n)) for n in range(2))
+
+    assert df.filter(pl.col("x").eq("a")).select("x").collect().to_dict(
+        as_series=False
+    ) == {"x": ["a", "a"]}
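For the regression guarded above, inspecting the optimized plan is a quick way to confirm that the filter survives even though the final projection takes nothing from the common subexpression. A sketch (the exact plan text varies across versions):

    import polars as pl

    lf = pl.LazyFrame({"x": "a", "y": 1})
    lf = pl.concat(lf.with_columns(pl.col("y").add(n)) for n in range(2))

    # Should print a plan whose cached subplan still filters on "x"; the
    # collected result matches the test: {"x": ["a", "a"]}.
    print(lf.filter(pl.col("x").eq("a")).select("x").explain())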