From a379e6837ad9a57aaa4285f5a7e8e6d75383586b Mon Sep 17 00:00:00 2001 From: Gijs Burghoorn Date: Fri, 21 Feb 2025 11:28:19 +0100 Subject: [PATCH] fix: Prefiltered optional plain primitive kernel (#21381) --- .../read/deserialize/primitive/plain/mod.rs | 10 +-- py-polars/tests/unit/io/test_parquet.py | 83 +++++++++++++++++++ 2 files changed, 87 insertions(+), 6 deletions(-) diff --git a/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain/mod.rs index b11d292d1aca..3a4e5fcb71b7 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain/mod.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain/mod.rs @@ -431,7 +431,7 @@ fn decode_masked_optional( let mut num_rows_left = num_rows; let mut value_offset = 0; - let mut iter = |mut f: u64, mut v: u64, len: usize| { + let mut iter = |mut f: u64, mut v: u64| { if num_rows_left == 0 { return false; } @@ -483,7 +483,7 @@ fn decode_masked_optional( unsafe { target_ptr = target_ptr.add(num_written); } - value_offset += len; + value_offset += num_chunk_values; num_rows_left -= num_written; num_values_left -= num_chunk_values; @@ -491,17 +491,15 @@ fn decode_masked_optional( }; for (f, v) in mask_iter.by_ref().zip(validity_iter.by_ref()) { - if !iter(f, v, 56) { + if !iter(f, v) { break; } } let (f, fl) = mask_iter.remainder(); let (v, vl) = validity_iter.remainder(); - assert_eq!(fl, vl); - - iter(f, v, fl); + iter(f, v); unsafe { target.set_len(start_length + num_rows) }; diff --git a/py-polars/tests/unit/io/test_parquet.py b/py-polars/tests/unit/io/test_parquet.py index f401821497ec..958adc7f9394 100644 --- a/py-polars/tests/unit/io/test_parquet.py +++ b/py-polars/tests/unit/io/test_parquet.py @@ -2912,3 +2912,86 @@ def test_nested_deprecated_int96_timestamps_21332() -> None: pl.read_parquet(f), df, ) + + +def test_final_masked_optional_iteration_21378() -> None: + # fmt: off + values = [ + 1, 0, 0, 0, 0, 1, 1, 1, + 1, 0, 0, 1, 1, 1, 1, 0, + 0, 1, 1, 1, 0, 1, 0, 0, + 1, 1, 0, 0, 0, 1, 1, 1, + 0, 1, 0, 0, 1, 1, 1, 1, + 0, 1, 1, 1, 0, 1, 0, 1, + 0, 1, 1, 0, 1, 0, 1, 1, + 0, 0, 0, 0, 1, 0, 0, 0, + 0, 1, 1, 1, 0, 0, 1, 1, + 0, 0, 1, 1, 0, 0, 0, 1, + 1, 1, 0, 1, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 0, + 0, 0, 1, 0, 1, 1, 0, 0, + 0, 1, 1, 0, 0, 1, 0, 0, + 1, 1, 1, 1, 0, 0, 1, 0, + 0, 1, 1, 0, 0, 1, 1, 1, + 1, 1, 1, 0, 1, 1, 0, 1, + 0, 1, 0, 1, 0, 1, 0, 1, + 0, 0, 0, 1, 1, 0, 0, 0, + 1, 1, 0, 1, 0, 1, 0, 1, + 0, 1, 0, 0, 0, 0, 0, 1, + 0, 0, 1, 1, 0, 0, 1, 1, + 0, 1, 0, 0, 0, 1, 1, 1, + 1, 0, 1, 0, 1, 0, 1, 1, + 1, 0, 1, 0, 0, 1, 0, 1, + 0, 1, 1, 1, 0, 0, 0, 1, + 1, 1, 1, 1, 1, 1, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 1, + 1, 1, 1, 0, 0, 0, 0, 0, + 1, 1, 1, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 1, 0, 0, 0, 0, 1, + 0, 0, 0, 0, 0, 1, 0, 0, + 1, 0, 1, 0, 0, 1, 0, 0, + 0, 1, 1, 1, 0, 0, 1, 1, + 1, 0, 1, 1, 0, 0, 0, 1, + 0, 0, 1, 1, 0, 1, 0, 1, + 0, 1, 1, 1, 0, 0, 0, 1, + 0, 0, 0, 1, 0, 1, 0, 1, + 0, 1, 0, 1, 0, 1, 1, 1, + 1, 0, 1, 1, 1, 1, 1, 0, + 1, 0, 1, 0, 0, 0, 1, 1, + 0, 0, 0, 1, 0, 0, 1, 0, + 0, 1, 0, 0, 1, 0, 1, 1, + 1, 0, 0, 1, 0, 1, 1, 0, + 0, 1, 0, 1, 1, 0, 1, 0, + 0, 0, 0, 1, 1, 1, 0, 0, + 0, 1, 0, 1, 1, 0, 1, 1, + 1, 1, 0, 1, 0, 1, 0, 1, + 1, 1, 0, 1, 0, 0, 1, 0, + 1, 1, 0, 1, 1, 0, 0, 1, + 0, 0, 0, 0, 0, 1, 0, 0, + 0, 1, 0, 0, 1, 1, 1, 1, + 1, 0, 1, 1, 1, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 1, 1, + ] + + df = pl.DataFrame( + [ + pl.Series("x", [None if x == 1 else 0.0 for x in values], pl.Float32), + pl.Series( + "f", + [False] * 164 + + [True] * 10 + + [False] * 264 + + [True] * 10, + pl.Boolean(), + ), + ] + ) + + f = io.BytesIO() + df.write_parquet(f) + f.seek(0) + + output = pl.scan_parquet(f, parallel="prefiltered").filter(pl.col.f).collect() + assert_frame_equal(df.filter(pl.col.f), output)