From 9164d2cdfd35d9c38717d6611abfdbc85e107e84 Mon Sep 17 00:00:00 2001 From: nameexhaustion Date: Thu, 20 Feb 2025 21:24:11 +1100 Subject: [PATCH] fix(rust): Fix height validation in `hstack_mut` was bypassed when adding to empty frame (#21335) --- crates/polars-core/src/frame/horizontal.rs | 148 +++++++++--------- crates/polars-core/src/frame/mod.rs | 53 ++----- crates/polars-core/src/frame/validation.rs | 67 ++++++++ crates/polars-ops/src/frame/pivot/mod.rs | 3 +- .../polars-ops/src/series/ops/to_dummies.rs | 3 +- crates/polars-schema/src/schema.rs | 22 ++- py-polars/tests/unit/dataframe/test_df.py | 2 +- 7 files changed, 167 insertions(+), 131 deletions(-) create mode 100644 crates/polars-core/src/frame/validation.rs diff --git a/crates/polars-core/src/frame/horizontal.rs b/crates/polars-core/src/frame/horizontal.rs index 6d2500b03568..932d37708092 100644 --- a/crates/polars-core/src/frame/horizontal.rs +++ b/crates/polars-core/src/frame/horizontal.rs @@ -1,28 +1,8 @@ -use polars_error::{polars_ensure, polars_err, PolarsResult}; -use polars_utils::aliases::PlHashSet; +use polars_error::{polars_err, PolarsResult}; use super::Column; use crate::datatypes::AnyValue; use crate::frame::DataFrame; -use crate::prelude::PlSmallStr; - -fn check_hstack( - col: &Column, - names: &mut PlHashSet, - height: usize, - is_empty: bool, -) -> PolarsResult<()> { - polars_ensure!( - col.len() == height || is_empty, - ShapeMismatch: "unable to hstack Series of length {} and DataFrame of height {}", - col.len(), height, - ); - polars_ensure!( - names.insert(col.name().clone()), - Duplicate: "unable to hstack, column with name {:?} already exists", col.name().as_str(), - ); - Ok(()) -} impl DataFrame { /// Add columns horizontally. @@ -31,28 +11,35 @@ impl DataFrame { /// The caller must ensure: /// - the length of all [`Column`] is equal to the height of this [`DataFrame`] /// - the columns names are unique + /// + /// Note: If `self` is empty, `self.height` will always be overridden by the height of the first + /// column in `columns`. + /// + /// Note that on a debug build this will panic on duplicates / height mismatch. pub unsafe fn hstack_mut_unchecked(&mut self, columns: &[Column]) -> &mut Self { - // If we don't have any columns yet, copy the height from the given columns. - if let Some(fst) = columns.first() { - if self.width() == 0 { - // SAFETY: The functions invariants asks for all columns to be the same length so - // that makes that a valid height. - unsafe { self.set_height(fst.len()) }; + self.clear_schema(); + self.columns.extend_from_slice(columns); + + if cfg!(debug_assertions) { + if let err @ Err(_) = DataFrame::validate_columns_slice(&self.columns) { + // Reset DataFrame state to before extend. + self.columns.truncate(self.columns.len() - columns.len()); + err.unwrap(); } } - if cfg!(debug_assertions) { - // It is an impl error if this fails. - self._validate_hstack(columns).unwrap(); + if let Some(c) = self.columns.first() { + unsafe { self.set_height(c.len()) }; } - self.clear_schema(); - self.columns.extend_from_slice(columns); self } /// Add multiple [`Column`] to a [`DataFrame`]. - /// The added `Series` are required to have the same length. + /// Errors if the resulting DataFrame columns have duplicate names or unequal heights. + /// + /// Note: If `self` is empty, `self.height` will always be overridden by the height of the first + /// column in `columns`. /// /// # Example /// @@ -63,28 +50,23 @@ impl DataFrame { /// } /// ``` pub fn hstack_mut(&mut self, columns: &[Column]) -> PolarsResult<&mut Self> { - self._validate_hstack(columns)?; - Ok(unsafe { self.hstack_mut_unchecked(columns) }) - } + self.clear_schema(); + self.columns.extend_from_slice(columns); - fn _validate_hstack(&self, columns: &[Column]) -> PolarsResult<()> { - let mut names = self - .columns - .iter() - .map(|c| c.name().clone()) - .collect::>(); - - let height = self.height(); - let is_empty = self.is_empty(); - // first loop check validity. We don't do this in a single pass otherwise - // this DataFrame is already modified when an error occurs. - for col in columns { - check_hstack(col, &mut names, height, is_empty)?; + if let err @ Err(_) = DataFrame::validate_columns_slice(&self.columns) { + // Reset DataFrame state to before extend. + self.columns.truncate(self.columns.len() - columns.len()); + err?; } - Ok(()) + if let Some(c) = self.columns.first() { + unsafe { self.set_height(c.len()) }; + } + + Ok(self) } } + /// Concat [`DataFrame`]s horizontally. /// Concat horizontally and extend with null values if lengths don't match pub fn concat_df_horizontal(dfs: &[DataFrame], check_duplicates: bool) -> PolarsResult { @@ -96,12 +78,23 @@ pub fn concat_df_horizontal(dfs: &[DataFrame], check_duplicates: bool) -> Polars let owned_df; + let mut out_width = 0; + + let all_equal_height = dfs.iter().all(|df| { + out_width += df.width(); + df.height() == output_height + }); + // if not all equal length, extend the DataFrame with nulls - let dfs = if !dfs.iter().all(|df| df.height() == output_height) { + let dfs = if !all_equal_height { + out_width = 0; + owned_df = dfs .iter() .cloned() .map(|mut df| { + out_width += df.width(); + if df.height() != output_height { let diff = output_height - df.height(); @@ -123,30 +116,41 @@ pub fn concat_df_horizontal(dfs: &[DataFrame], check_duplicates: bool) -> Polars dfs }; - let mut first_df = dfs[0].clone(); - let height = first_df.height(); - let is_empty = first_df.is_empty(); + let mut acc_cols = Vec::with_capacity(out_width); - let mut names = if check_duplicates { - first_df - .columns - .iter() - .map(|s| s.name().clone()) - .collect::>() - } else { - Default::default() - }; + for df in dfs { + acc_cols.extend(df.get_columns().iter().cloned()); + } - for df in &dfs[1..] { - let cols = df.get_columns(); + if check_duplicates { + DataFrame::validate_columns_slice(&acc_cols)?; + } - if check_duplicates { - for col in cols { - check_hstack(col, &mut names, height, is_empty)?; - } - } + let df = unsafe { DataFrame::new_no_checks_height_from_first(acc_cols) }; + + Ok(df) +} - unsafe { first_df.hstack_mut_unchecked(cols) }; +#[cfg(test)] +mod tests { + use polars_error::PolarsError; + + #[test] + fn test_hstack_mut_empty_frame_height_validation() { + use crate::frame::DataFrame; + use crate::prelude::{Column, DataType}; + let mut df = DataFrame::empty(); + let result = df.hstack_mut(&[ + Column::full_null("a".into(), 1, &DataType::Null), + Column::full_null("b".into(), 3, &DataType::Null), + ]); + + assert!( + matches!(result, Err(PolarsError::ShapeMismatch(_))), + "expected shape mismatch error" + ); + + // Ensure the DataFrame is not mutated in the error case. + assert_eq!(df.width(), 0); } - Ok(first_df) } diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index e1babf6ed0aa..44fdf6c68f34 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -4,7 +4,7 @@ use std::{mem, ops}; use arrow::datatypes::ArrowSchemaRef; use polars_row::ArrayRef; -use polars_schema::schema::debug_ensure_matching_schema_names; +use polars_schema::schema::ensure_matching_schema_names; use polars_utils::itertools::Itertools; use rayon::prelude::*; @@ -31,6 +31,7 @@ pub(crate) mod horizontal; pub mod row; mod top_k; mod upstream_traits; +mod validation; use arrow::record_batch::{RecordBatch, RecordBatchT}; use polars_utils::pl_str::PlSmallStr; @@ -260,6 +261,8 @@ impl DataFrame { /// Create a DataFrame from a Vector of Series. /// + /// Errors if a column names are not unique, or if heights are not all equal. + /// /// # Example /// /// ``` @@ -271,17 +274,9 @@ impl DataFrame { /// # Ok::<(), PolarsError>(()) /// ``` pub fn new(columns: Vec) -> PolarsResult { - ensure_names_unique(&columns, |s| s.name().as_str())?; - - let Some(fst) = columns.first() else { - return Ok(DataFrame { - height: 0, - columns, - cached_schema: OnceLock::new(), - }); - }; - - Self::new_with_height(fst.len(), columns) + DataFrame::validate_columns_slice(&columns) + .map_err(|e| e.wrap_msg(|e| format!("could not create a new DataFrame: {}", e)))?; + Ok(unsafe { Self::new_no_checks_height_from_first(columns) }) } pub fn new_with_height(height: usize, columns: Vec) -> PolarsResult { @@ -522,11 +517,7 @@ impl DataFrame { /// having an equal length and a unique name, if not this may panic down the line. pub unsafe fn new_no_checks(height: usize, columns: Vec) -> DataFrame { if cfg!(debug_assertions) { - ensure_names_unique(&columns, |s| s.name().as_str()).unwrap(); - - for col in &columns { - assert_eq!(col.len(), height); - } + DataFrame::validate_columns_slice(&columns).unwrap(); } unsafe { Self::_new_no_checks_impl(height, columns) } @@ -544,30 +535,6 @@ impl DataFrame { } } - /// Create a new `DataFrame` but does not check the length of the `Series`, - /// only check for duplicates. - /// - /// It is advised to use [DataFrame::new] in favor of this method. - /// - /// # Safety - /// - /// It is the callers responsibility to uphold the contract of all `Series` - /// having an equal length, if not this may panic down the line. - pub unsafe fn new_no_length_checks(columns: Vec) -> PolarsResult { - ensure_names_unique(&columns, |s| s.name().as_str())?; - - Ok(if cfg!(debug_assertions) { - Self::new(columns).unwrap() - } else { - let height = Self::infer_height(&columns); - DataFrame { - height, - columns, - cached_schema: OnceLock::new(), - } - }) - } - /// Shrink the capacity of this DataFrame to fit its length. pub fn shrink_to_fit(&mut self) { // Don't parallelize this. Memory overhead @@ -1845,7 +1812,9 @@ impl DataFrame { cols: &[PlSmallStr], schema: &Schema, ) -> PolarsResult> { - debug_ensure_matching_schema_names(schema, self.schema())?; + if cfg!(debug_assertions) { + ensure_matching_schema_names(schema, self.schema())?; + } cols.iter() .map(|name| { diff --git a/crates/polars-core/src/frame/validation.rs b/crates/polars-core/src/frame/validation.rs new file mode 100644 index 000000000000..bb54f6dfa257 --- /dev/null +++ b/crates/polars-core/src/frame/validation.rs @@ -0,0 +1,67 @@ +use polars_error::{polars_bail, PolarsResult}; +use polars_utils::aliases::{InitHashMaps, PlHashSet}; + +use super::column::Column; +use super::DataFrame; + +impl DataFrame { + /// Ensure all equal height and names are unique. + /// + /// An Ok() result indicates `columns` is a valid state for a DataFrame. + pub fn validate_columns_slice(columns: &[Column]) -> PolarsResult<()> { + if columns.len() <= 1 { + return Ok(()); + } + + if columns.len() <= 4 { + // Too small to be worth spawning a hashmap for, this is at most 6 comparisons. + for i in 0..columns.len() - 1 { + let name = columns[i].name(); + let height = columns[i].len(); + + for other in columns.iter().skip(i + 1) { + if other.name() == name { + polars_bail!(duplicate = name); + } + + if other.len() != height { + polars_bail!( + ShapeMismatch: + "height of column '{}' ({}) does not match height of column '{}' ({})", + other.name(), other.len(), name, height + ) + } + } + } + } else { + let first = &columns[0]; + + let first_len = first.len(); + let first_name = first.name(); + + let mut names = PlHashSet::with_capacity(columns.len()); + names.insert(first_name); + + for col in &columns[1..] { + let col_name = col.name(); + let col_len = col.len(); + + if col_len != first_len { + polars_bail!( + ShapeMismatch: + "height of column '{}' ({}) does not match height of column '{}' ({})", + col_name, col_len, first_name, first_len + ) + } + + if names.contains(col_name) { + polars_bail!(duplicate = col_name) + } + + names.insert(col_name); + } + } + + Ok(()) + } +} diff --git a/crates/polars-ops/src/frame/pivot/mod.rs b/crates/polars-ops/src/frame/pivot/mod.rs index 34d553a65db0..a65c06c7454a 100644 --- a/crates/polars-ops/src/frame/pivot/mod.rs +++ b/crates/polars-ops/src/frame/pivot/mod.rs @@ -379,6 +379,5 @@ fn pivot_impl_single_column( }); out?; - // SAFETY: length has already been checked. - unsafe { DataFrame::new_no_length_checks(final_cols) } + DataFrame::new(final_cols) } diff --git a/crates/polars-ops/src/series/ops/to_dummies.rs b/crates/polars-ops/src/series/ops/to_dummies.rs index eb2cf3a228c1..5737d35b0eae 100644 --- a/crates/polars-ops/src/series/ops/to_dummies.rs +++ b/crates/polars-ops/src/series/ops/to_dummies.rs @@ -46,8 +46,7 @@ impl ToDummies for Series { }) .collect::>(); - // SAFETY: `dummies_helper` functions preserve `self.len()` length - unsafe { DataFrame::new_no_length_checks(sort_columns(columns)) } + DataFrame::new(sort_columns(columns)) } } diff --git a/crates/polars-schema/src/schema.rs b/crates/polars-schema/src/schema.rs index 38dec96ccbb5..21abfc0ddbf3 100644 --- a/crates/polars-schema/src/schema.rs +++ b/crates/polars-schema/src/schema.rs @@ -457,18 +457,16 @@ where } } -pub fn debug_ensure_matching_schema_names(lhs: &Schema, rhs: &Schema) -> PolarsResult<()> { - if cfg!(debug_assertions) { - let lhs = lhs.iter_names().collect::>(); - let rhs = rhs.iter_names().collect::>(); - - if lhs != rhs { - polars_bail!( - SchemaMismatch: - "lhs: {:?} rhs: {:?}", - lhs, rhs - ) - } +pub fn ensure_matching_schema_names(lhs: &Schema, rhs: &Schema) -> PolarsResult<()> { + let lhs_names = lhs.iter_names(); + let rhs_names = rhs.iter_names(); + + if !(lhs_names.len() == rhs_names.len() && lhs_names.zip(rhs_names).all(|(l, r)| l == r)) { + polars_bail!( + SchemaMismatch: + "lhs: {:?} rhs: {:?}", + lhs.iter_names().collect::>(), rhs.iter_names().collect::>() + ) } Ok(()) diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index 235fb0e44586..9e718b6bd8af 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -3009,7 +3009,7 @@ def test_get_column_index() -> None: def test_dataframe_creation_with_different_series_lengths_19795() -> None: with pytest.raises( ShapeError, - match='could not create a new DataFrame: series "a" has length 2 while series "b" has length 1', + match=r"could not create a new DataFrame: height of column 'b' \(1\) does not match height of column 'a' \(2\)", ): pl.DataFrame({"a": [1, 2], "b": [1]})