Skip to content

Commit

Permalink
fix(rust): Fix height validation in hstack_mut was bypassed when ad…
Browse files Browse the repository at this point in the history
…ding to empty frame (#21335)
  • Loading branch information
nameexhaustion authored Feb 20, 2025
1 parent c0f345f commit 9164d2c
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 131 deletions.
148 changes: 76 additions & 72 deletions crates/polars-core/src/frame/horizontal.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,8 @@
use polars_error::{polars_ensure, polars_err, PolarsResult};
use polars_utils::aliases::PlHashSet;
use polars_error::{polars_err, PolarsResult};

use super::Column;
use crate::datatypes::AnyValue;
use crate::frame::DataFrame;
use crate::prelude::PlSmallStr;

/// Validate that `col` can be horizontally stacked, recording its name in `names`.
///
/// Two checks are made, in this order:
/// 1. `col.len()` must equal `height`; this check is skipped entirely when `is_empty` is `true`.
/// 2. The name of `col` must not already be present in `names`.
///
/// On success the column's name has been inserted into `names`, so repeated calls with the
/// same set also detect duplicates among the columns being added.
///
/// # Errors
///
/// - `ShapeMismatch` if the length check fails.
/// - `Duplicate` if the name was already present in `names`.
fn check_hstack(
col: &Column,
names: &mut PlHashSet<PlSmallStr>,
height: usize,
is_empty: bool,
) -> PolarsResult<()> {
// Length check; `is_empty == true` bypasses it regardless of `col.len()`.
polars_ensure!(
col.len() == height || is_empty,
ShapeMismatch: "unable to hstack Series of length {} and DataFrame of height {}",
col.len(), height,
);
// `insert` returns `false` when the name is already in the set, which is the duplicate case.
// It only runs if the length check above passed.
polars_ensure!(
names.insert(col.name().clone()),
Duplicate: "unable to hstack, column with name {:?} already exists", col.name().as_str(),
);
Ok(())
}

impl DataFrame {
/// Add columns horizontally.
Expand All @@ -31,28 +11,35 @@ impl DataFrame {
/// The caller must ensure:
/// - the length of all [`Column`] is equal to the height of this [`DataFrame`]
/// - the columns names are unique
///
/// Note: If `self` is empty, `self.height` will always be overridden by the height of the first
/// column in `columns`.
///
/// Note that on a debug build this will panic on duplicates / height mismatch.
pub unsafe fn hstack_mut_unchecked(&mut self, columns: &[Column]) -> &mut Self {
// If we don't have any columns yet, copy the height from the given columns.
if let Some(fst) = columns.first() {
if self.width() == 0 {
// SAFETY: The function's invariants require all columns to be the same length, so the
// first column's length is a valid height.
unsafe { self.set_height(fst.len()) };
self.clear_schema();
self.columns.extend_from_slice(columns);

if cfg!(debug_assertions) {
if let err @ Err(_) = DataFrame::validate_columns_slice(&self.columns) {
// Reset DataFrame state to before extend.
self.columns.truncate(self.columns.len() - columns.len());
err.unwrap();
}
}

if cfg!(debug_assertions) {
// It is an impl error if this fails.
self._validate_hstack(columns).unwrap();
if let Some(c) = self.columns.first() {
unsafe { self.set_height(c.len()) };
}

self.clear_schema();
self.columns.extend_from_slice(columns);
self
}

/// Add multiple [`Column`] to a [`DataFrame`].
/// The added `Series` are required to have the same length.
/// Errors if the resulting DataFrame columns have duplicate names or unequal heights.
///
/// Note: If `self` is empty, `self.height` will always be overridden by the height of the first
/// column in `columns`.
///
/// # Example
///
Expand All @@ -63,28 +50,23 @@ impl DataFrame {
/// }
/// ```
pub fn hstack_mut(&mut self, columns: &[Column]) -> PolarsResult<&mut Self> {
self._validate_hstack(columns)?;
Ok(unsafe { self.hstack_mut_unchecked(columns) })
}
self.clear_schema();
self.columns.extend_from_slice(columns);

fn _validate_hstack(&self, columns: &[Column]) -> PolarsResult<()> {
let mut names = self
.columns
.iter()
.map(|c| c.name().clone())
.collect::<PlHashSet<_>>();

let height = self.height();
let is_empty = self.is_empty();
// first loop check validity. We don't do this in a single pass otherwise
// this DataFrame is already modified when an error occurs.
for col in columns {
check_hstack(col, &mut names, height, is_empty)?;
if let err @ Err(_) = DataFrame::validate_columns_slice(&self.columns) {
// Reset DataFrame state to before extend.
self.columns.truncate(self.columns.len() - columns.len());
err?;
}

Ok(())
if let Some(c) = self.columns.first() {
unsafe { self.set_height(c.len()) };
}

Ok(self)
}
}

/// Concat [`DataFrame`]s horizontally.
/// Concat horizontally and extend with null values if lengths don't match
pub fn concat_df_horizontal(dfs: &[DataFrame], check_duplicates: bool) -> PolarsResult<DataFrame> {
Expand All @@ -96,12 +78,23 @@ pub fn concat_df_horizontal(dfs: &[DataFrame], check_duplicates: bool) -> Polars

let owned_df;

let mut out_width = 0;

let all_equal_height = dfs.iter().all(|df| {
out_width += df.width();
df.height() == output_height
});

// if not all equal length, extend the DataFrame with nulls
let dfs = if !dfs.iter().all(|df| df.height() == output_height) {
let dfs = if !all_equal_height {
out_width = 0;

owned_df = dfs
.iter()
.cloned()
.map(|mut df| {
out_width += df.width();

if df.height() != output_height {
let diff = output_height - df.height();

Expand All @@ -123,30 +116,41 @@ pub fn concat_df_horizontal(dfs: &[DataFrame], check_duplicates: bool) -> Polars
dfs
};

let mut first_df = dfs[0].clone();
let height = first_df.height();
let is_empty = first_df.is_empty();
let mut acc_cols = Vec::with_capacity(out_width);

let mut names = if check_duplicates {
first_df
.columns
.iter()
.map(|s| s.name().clone())
.collect::<PlHashSet<_>>()
} else {
Default::default()
};
for df in dfs {
acc_cols.extend(df.get_columns().iter().cloned());
}

for df in &dfs[1..] {
let cols = df.get_columns();
if check_duplicates {
DataFrame::validate_columns_slice(&acc_cols)?;
}

if check_duplicates {
for col in cols {
check_hstack(col, &mut names, height, is_empty)?;
}
}
let df = unsafe { DataFrame::new_no_checks_height_from_first(acc_cols) };

Ok(df)
}

unsafe { first_df.hstack_mut_unchecked(cols) };
#[cfg(test)]
mod tests {
use polars_error::PolarsError;

#[test]
fn test_hstack_mut_empty_frame_height_validation() {
    use crate::frame::DataFrame;
    use crate::prelude::{Column, DataType};

    // Two columns whose heights disagree (1 vs 3), stacked onto a frame without columns.
    let mismatched = [
        Column::full_null("a".into(), 1, &DataType::Null),
        Column::full_null("b".into(), 3, &DataType::Null),
    ];

    let mut frame = DataFrame::empty();
    let outcome = frame.hstack_mut(&mismatched);

    assert!(
        matches!(outcome, Err(PolarsError::ShapeMismatch(_))),
        "expected shape mismatch error"
    );

    // A failed hstack must leave the frame without any of the rejected columns.
    assert_eq!(frame.width(), 0);
}
Ok(first_df)
}
53 changes: 11 additions & 42 deletions crates/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{mem, ops};

use arrow::datatypes::ArrowSchemaRef;
use polars_row::ArrayRef;
use polars_schema::schema::debug_ensure_matching_schema_names;
use polars_schema::schema::ensure_matching_schema_names;
use polars_utils::itertools::Itertools;
use rayon::prelude::*;

Expand All @@ -31,6 +31,7 @@ pub(crate) mod horizontal;
pub mod row;
mod top_k;
mod upstream_traits;
mod validation;

use arrow::record_batch::{RecordBatch, RecordBatchT};
use polars_utils::pl_str::PlSmallStr;
Expand Down Expand Up @@ -260,6 +261,8 @@ impl DataFrame {

/// Create a DataFrame from a Vector of Series.
///
/// Errors if column names are not unique, or if heights are not all equal.
///
/// # Example
///
/// ```
Expand All @@ -271,17 +274,9 @@ impl DataFrame {
/// # Ok::<(), PolarsError>(())
/// ```
pub fn new(columns: Vec<Column>) -> PolarsResult<Self> {
ensure_names_unique(&columns, |s| s.name().as_str())?;

let Some(fst) = columns.first() else {
return Ok(DataFrame {
height: 0,
columns,
cached_schema: OnceLock::new(),
});
};

Self::new_with_height(fst.len(), columns)
DataFrame::validate_columns_slice(&columns)
.map_err(|e| e.wrap_msg(|e| format!("could not create a new DataFrame: {}", e)))?;
Ok(unsafe { Self::new_no_checks_height_from_first(columns) })
}

pub fn new_with_height(height: usize, columns: Vec<Column>) -> PolarsResult<Self> {
Expand Down Expand Up @@ -522,11 +517,7 @@ impl DataFrame {
/// having an equal length and a unique name, if not this may panic down the line.
pub unsafe fn new_no_checks(height: usize, columns: Vec<Column>) -> DataFrame {
if cfg!(debug_assertions) {
ensure_names_unique(&columns, |s| s.name().as_str()).unwrap();

for col in &columns {
assert_eq!(col.len(), height);
}
DataFrame::validate_columns_slice(&columns).unwrap();
}

unsafe { Self::_new_no_checks_impl(height, columns) }
Expand All @@ -544,30 +535,6 @@ impl DataFrame {
}
}

/// Build a `DataFrame` from `columns`, verifying only that the column names are unique.
///
/// Prefer [DataFrame::new] over this method.
///
/// Returns an error if two columns share a name. In debug builds the columns are
/// additionally passed through [DataFrame::new] and its result is unwrapped.
///
/// # Safety
///
/// The caller is responsible for ensuring that every `Series` has the same length;
/// if that contract is broken this may panic down the line.
pub unsafe fn new_no_length_checks(columns: Vec<Column>) -> PolarsResult<DataFrame> {
    ensure_names_unique(&columns, |s| s.name().as_str())?;

    // Debug builds use the checked constructor, so a violated contract panics right here.
    if cfg!(debug_assertions) {
        return Ok(Self::new(columns).unwrap());
    }

    let height = Self::infer_height(&columns);
    Ok(DataFrame {
        height,
        columns,
        cached_schema: OnceLock::new(),
    })
}

/// Shrink the capacity of this DataFrame to fit its length.
pub fn shrink_to_fit(&mut self) {
// Don't parallelize this. Memory overhead
Expand Down Expand Up @@ -1845,7 +1812,9 @@ impl DataFrame {
cols: &[PlSmallStr],
schema: &Schema,
) -> PolarsResult<Vec<Column>> {
debug_ensure_matching_schema_names(schema, self.schema())?;
if cfg!(debug_assertions) {
ensure_matching_schema_names(schema, self.schema())?;
}

cols.iter()
.map(|name| {
Expand Down
67 changes: 67 additions & 0 deletions crates/polars-core/src/frame/validation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
use polars_error::{polars_bail, PolarsResult};
use polars_utils::aliases::{InitHashMaps, PlHashSet};

use super::column::Column;
use super::DataFrame;

impl DataFrame {
    /// Ensure all columns have equal height and that their names are unique.
    ///
    /// An `Ok(())` result indicates `columns` is a valid state for a DataFrame.
    /// Slices with zero or one column are always valid.
    ///
    /// # Errors
    ///
    /// - A duplicate-name error if two columns share a name.
    /// - `ShapeMismatch` if two columns differ in length.
    ///
    /// When a slice violates both rules, which error is reported depends on the order in
    /// which the columns are compared; callers should not rely on one taking precedence.
    pub fn validate_columns_slice(columns: &[Column]) -> PolarsResult<()> {
        if columns.len() <= 1 {
            return Ok(());
        }

        if columns.len() <= 4 {
            // Too small to be worth spawning a hashmap for, this is at most 6 comparisons.
            for (i, col) in columns.iter().enumerate() {
                let name = col.name();
                let height = col.len();

                // Compare against every later column; the last column has nothing left to
                // compare with, so its inner loop is empty.
                for other in &columns[i + 1..] {
                    if other.name() == name {
                        polars_bail!(duplicate = name);
                    }

                    if other.len() != height {
                        polars_bail!(
                            ShapeMismatch:
                            "height of column '{}' ({}) does not match height of column '{}' ({})",
                            other.name(), other.len(), name, height
                        );
                    }
                }
            }
        } else {
            // The first column is the reference for height; every name seen goes in a set.
            let first = &columns[0];

            let first_len = first.len();
            let first_name = first.name();

            let mut names = PlHashSet::with_capacity(columns.len());
            names.insert(first_name);

            for col in &columns[1..] {
                let col_name = col.name();
                let col_len = col.len();

                if col_len != first_len {
                    polars_bail!(
                        ShapeMismatch:
                        "height of column '{}' ({}) does not match height of column '{}' ({})",
                        col_name, col_len, first_name, first_len
                    );
                }

                // `insert` returns `false` if the name was already present, so one hash
                // lookup both detects the duplicate and records the name.
                if !names.insert(col_name) {
                    polars_bail!(duplicate = col_name);
                }
            }
        }

        Ok(())
    }
}
3 changes: 1 addition & 2 deletions crates/polars-ops/src/frame/pivot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,5 @@ fn pivot_impl_single_column(
});
out?;

// SAFETY: length has already been checked.
unsafe { DataFrame::new_no_length_checks(final_cols) }
DataFrame::new(final_cols)
}
3 changes: 1 addition & 2 deletions crates/polars-ops/src/series/ops/to_dummies.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ impl ToDummies for Series {
})
.collect::<Vec<_>>();

// SAFETY: `dummies_helper` functions preserve `self.len()` length
unsafe { DataFrame::new_no_length_checks(sort_columns(columns)) }
DataFrame::new(sort_columns(columns))
}
}

Expand Down
Loading

0 comments on commit 9164d2c

Please sign in to comment.