feat: Conserve Parquet SortingColumns for ints
This PR makes it so that Parquet `SortingColumns` metadata can be used to preserve the sorted
flag when reading into Polars. Currently, this is only enabled for integers, as other types
(e.g. NaN ordering for floats) might require additional considerations. However, enabling this
feature for other types is now trivial.

```python
import polars as pl
import pyarrow.parquet as pq
import io

f = io.BytesIO()

df = pl.DataFrame({
    "a": [1, 2, 3, 4, 5, None],
    "b": [1.0, 2.0, 3.0, 4.0, 5.0, None],
    "c": range(6),
})

pq.write_table(
    df.to_arrow(),
    f,
    sorting_columns=[
        pq.SortingColumn(0, False, False),
        pq.SortingColumn(1, False, False),
    ],
)

f.seek(0)
df = pl.read_parquet(f)._to_metadata(stats='sorted_asc')
```

Before:

```console
shape: (3, 2)
┌─────────────┬────────────┐
│ column_name ┆ sorted_asc │
│ ---         ┆ ---        │
│ str         ┆ bool       │
╞═════════════╪════════════╡
│ a           ┆ false      │
│ b           ┆ false      │
│ c           ┆ false      │
└─────────────┴────────────┘
```

After:

```console
shape: (3, 2)
┌─────────────┬────────────┐
│ column_name ┆ sorted_asc │
│ ---         ┆ ---        │
│ str         ┆ bool       │
╞═════════════╪════════════╡
│ a           ┆ true       │
│ b           ┆ false      │
│ c           ┆ false      │
└─────────────┴────────────┘
```
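
For a quick sanity check, the flag is also visible from Python through the `Series.flags` property. The sketch below is illustrative rather than part of this commit; it assumes the `f` buffer from the snippet above:

```python
f.seek(0)
out = pl.read_parquet(f)

# With this change, the integer column "a" should carry the ascending-sorted
# flag, while the float column "b" deliberately does not (yet).
print(out.get_column("a").flags)  # {'SORTED_ASC': True, 'SORTED_DESC': False}
print(out.get_column("b").flags)  # {'SORTED_ASC': False, 'SORTED_DESC': False}
```

Once the flag is set, downstream operations such as `max()` or a redundant `sort()` on `"a"` can take their sorted fast paths instead of rescanning the data.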
coastalwhite committed Oct 15, 2024
1 parent e29e9df commit 3cd0c62
Showing 3 changed files with 142 additions and 16 deletions.
103 changes: 88 additions & 15 deletions crates/polars-io/src/parquet/read/read_impl.rs
```diff
@@ -7,8 +7,9 @@ use arrow::bitmap::MutableBitmap;
 use arrow::datatypes::ArrowSchemaRef;
 use polars_core::chunked_array::builder::NullChunkedBuilder;
 use polars_core::prelude::*;
+use polars_core::series::IsSorted;
 use polars_core::utils::{accumulate_dataframes_vertical, split_df};
-use polars_core::POOL;
+use polars_core::{config, POOL};
 use polars_parquet::parquet::error::ParquetResult;
 use polars_parquet::parquet::statistics::Statistics;
 use polars_parquet::read::{
@@ -60,6 +61,57 @@ fn assert_dtypes(dtype: &ArrowDataType) {
     }
 }
 
+fn should_copy_sortedness(dtype: &DataType) -> bool {
+    // @NOTE: For now, we are a bit conservative with this.
+    use DataType as D;
+
+    matches!(
+        dtype,
+        D::Int8 | D::Int16 | D::Int32 | D::Int64 | D::UInt8 | D::UInt16 | D::UInt32 | D::UInt64
+    )
+}
+
+fn try_set_sorted_flag(
+    series: &mut Series,
+    col_idx: usize,
+    sorting_map: &PlHashMap<usize, IsSorted>,
+) {
+    if let Some(is_sorted) = sorting_map.get(&col_idx) {
+        if should_copy_sortedness(series.dtype()) {
+            if config::verbose() {
+                eprintln!(
+                    "Parquet conserved SortingColumn for column chunk of '{}' to {is_sorted:?}",
+                    series.name()
+                );
+            }
+
+            series.set_sorted_flag(*is_sorted);
+        }
+    }
+}
+
+fn create_sorting_map(md: &RowGroupMetadata) -> PlHashMap<usize, IsSorted> {
+    let capacity = md.sorting_columns().map_or(0, |s| s.len());
+    let mut sorting_map = PlHashMap::with_capacity(capacity);
+
+    if let Some(sorting_columns) = md.sorting_columns() {
+        for sorting in sorting_columns {
+            let prev_value = sorting_map.insert(
+                sorting.column_idx as usize,
+                if sorting.descending {
+                    IsSorted::Descending
+                } else {
+                    IsSorted::Ascending
+                },
+            );
+
+            debug_assert!(prev_value.is_none());
+        }
+    }
+
+    sorting_map
+}
+
 fn column_idx_to_series(
     column_i: usize,
     // The metadata belonging to this column
@@ -320,6 +372,8 @@ fn rg_to_dfs_prefiltered(
         }
     }
 
+    let sorting_map = create_sorting_map(md);
+
     // Collect the data for the live columns
     let live_columns = (0..num_live_columns)
         .into_par_iter()
@@ -338,8 +392,12 @@ fn rg_to_dfs_prefiltered(
 
             let part = iter.collect::<Vec<_>>();
 
-            column_idx_to_series(col_idx, part.as_slice(), None, schema, store)
-                .map(Column::from)
+            let mut series =
+                column_idx_to_series(col_idx, part.as_slice(), None, schema, store)?;
+
+            try_set_sorted_flag(&mut series, col_idx, &sorting_map);
+
+            Ok(series.into_column())
         })
         .collect::<PolarsResult<Vec<_>>>()?;
 
@@ -445,7 +503,7 @@ fn rg_to_dfs_prefiltered(
                     array.filter(&mask_arr)
                 };
 
-                let array = if mask_setting.should_prefilter(
+                let mut series = if mask_setting.should_prefilter(
                     prefilter_cost,
                     &schema.get_at_index(col_idx).unwrap().1.dtype,
                 ) {
@@ -454,9 +512,11 @@ fn rg_to_dfs_prefiltered(
                     post()?
                 };
 
-                debug_assert_eq!(array.len(), filter_mask.set_bits());
-
-                Ok(array.into_column())
+                debug_assert_eq!(series.len(), filter_mask.set_bits());
+
+                try_set_sorted_flag(&mut series, col_idx, &sorting_map);
+
+                Ok(series.into_column())
             })
             .collect::<PolarsResult<Vec<Column>>>()?;
 
@@ -569,6 +629,8 @@ fn rg_to_dfs_optionally_par_over_columns(
         assert!(std::env::var("POLARS_PANIC_IF_PARQUET_PARSED").is_err())
     }
 
+    let sorting_map = create_sorting_map(md);
+
     let columns = if let ParallelStrategy::Columns = parallel {
         POOL.install(|| {
             projection
@@ -586,14 +648,17 @@ fn rg_to_dfs_optionally_par_over_columns(
 
                     let part = iter.collect::<Vec<_>>();
 
-                    column_idx_to_series(
+                    let mut series = column_idx_to_series(
                         *column_i,
                         part.as_slice(),
                         Some(Filter::new_ranged(rg_slice.0, rg_slice.0 + rg_slice.1)),
                         schema,
                         store,
-                    )
-                    .map(Column::from)
+                    )?;
+
+                    try_set_sorted_flag(&mut series, *column_i, &sorting_map);
+
+                    Ok(series.into_column())
                 })
                 .collect::<PolarsResult<Vec<_>>>()
         })?
@@ -613,14 +678,17 @@ fn rg_to_dfs_optionally_par_over_columns(
 
                 let part = iter.collect::<Vec<_>>();
 
-                column_idx_to_series(
+                let mut series = column_idx_to_series(
                     *column_i,
                     part.as_slice(),
                     Some(Filter::new_ranged(rg_slice.0, rg_slice.0 + rg_slice.1)),
                     schema,
                     store,
-                )
-                .map(Column::from)
+                )?;
+
+                try_set_sorted_flag(&mut series, *column_i, &sorting_map);
+
+                Ok(series.into_column())
             })
             .collect::<PolarsResult<Vec<_>>>()?
     };
@@ -705,6 +773,8 @@ fn rg_to_dfs_par_over_rg(
                 assert!(std::env::var("POLARS_PANIC_IF_PARQUET_PARSED").is_err())
             }
 
+            let sorting_map = create_sorting_map(md);
+
             let columns = projection
                 .iter()
                 .map(|column_i| {
@@ -720,14 +790,17 @@ fn rg_to_dfs_par_over_rg(
 
                     let part = iter.collect::<Vec<_>>();
 
-                    column_idx_to_series(
+                    let mut series = column_idx_to_series(
                         *column_i,
                         part.as_slice(),
                         Some(Filter::new_ranged(slice.0, slice.0 + slice.1)),
                         schema,
                         store,
-                    )
-                    .map(Column::from)
+                    )?;
+
+                    try_set_sorted_flag(&mut series, *column_i, &sorting_map);
+
+                    Ok(series.into_column())
                 })
                 .collect::<PolarsResult<Vec<_>>>()?;
 
```
10 changes: 9 additions & 1 deletion crates/polars-parquet/src/parquet/metadata/row_metadata.rs
```diff
@@ -1,7 +1,7 @@
 use std::sync::Arc;
 
 use hashbrown::hash_map::RawEntryMut;
-use parquet_format_safe::RowGroup;
+use parquet_format_safe::{RowGroup, SortingColumn};
 use polars_utils::aliases::{InitHashMaps, PlHashMap};
 use polars_utils::idx_vec::UnitVec;
 use polars_utils::pl_str::PlSmallStr;
@@ -41,6 +41,7 @@ pub struct RowGroupMetadata {
     num_rows: usize,
     total_byte_size: usize,
     full_byte_range: core::ops::Range<u64>,
+    sorting_columns: Option<Vec<SortingColumn>>,
 }
 
 impl RowGroupMetadata {
@@ -85,6 +86,10 @@ impl RowGroupMetadata {
         self.columns.iter().map(|x| x.byte_range())
     }
 
+    pub fn sorting_columns(&self) -> Option<&[SortingColumn]> {
+        self.sorting_columns.as_deref()
+    }
+
     /// Method to convert from Thrift.
     pub(crate) fn try_from_thrift(
         schema_descr: &SchemaDescriptor,
@@ -106,6 +111,8 @@ impl RowGroupMetadata {
             0..0
         };
 
+        let sorting_columns = rg.sorting_columns.clone();
+
         let columns = rg
             .columns
             .into_iter()
@@ -131,6 +138,7 @@ impl RowGroupMetadata {
             num_rows,
             total_byte_size,
             full_byte_range,
+            sorting_columns,
         })
     }
 }
```
45 changes: 45 additions & 0 deletions py-polars/tests/unit/io/test_parquet.py
```diff
@@ -1990,3 +1990,48 @@ def test_nested_nonnullable_19158() -> None:
 
     f.seek(0)
     assert_frame_equal(pl.read_parquet(f), pl.DataFrame(tbl))
+
+
+@pytest.mark.parametrize("parallel", ["prefiltered", "columns", "row_groups", "auto"])
+def test_conserve_sortedness(parallel: pl.ParallelStrategy) -> None:
+    f = io.BytesIO()
+
+    df = pl.DataFrame(
+        {
+            "a": [1, 2, 3, 4, 5, None],
+            "b": [1.0, 2.0, 3.0, 4.0, 5.0, None],
+            "c": [None, 5, 4, 3, 2, 1],
+            "d": [None, 5.0, 4.0, 3.0, 2.0, 1.0],
+            "a_nosort": [1, 2, 3, 4, 5, None],
+            "f": range(6),
+        }
+    )
+
+    pq.write_table(
+        df.to_arrow(),
+        f,
+        sorting_columns=[
+            pq.SortingColumn(0, False, False),
+            pq.SortingColumn(1, False, False),
+            pq.SortingColumn(2, True, True),
+            pq.SortingColumn(3, True, True),
+        ],
+    )
+
+    f.seek(0)
+    df = pl.scan_parquet(f, parallel=parallel).filter(pl.col.f > 1).collect()
+
+    cols = ["a", "b", "c", "d", "a_nosort"]
+
+    # @NOTE: We don't conserve sortedness for anything except integers at the
+    # moment.
+    assert_frame_equal(
+        df._to_metadata(cols, ["sorted_asc", "sorted_dsc"]),
+        pl.DataFrame(
+            {
+                "column_name": cols,
+                "sorted_asc": [True, False, False, False, False],
+                "sorted_dsc": [False, False, True, False, False],
+            }
+        ),
+    )
```
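To exercise the new test locally, something along these lines should work from a Polars checkout (the command is an assumption, not part of the commit):

```console
pytest py-polars/tests/unit/io/test_parquet.py -k test_conserve_sortedness
```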