Skip to content

Commit

Permalink
fix: Prefiltered optional plain primitive kernel (#21381)
Browse files (browse the repository at this point in the history)
  • Loading branch information
coastalwhite authored Feb 21, 2025
1 parent 25d83c6 commit a379e68
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ fn decode_masked_optional<B: AlignedBytes>(
let mut num_rows_left = num_rows;
let mut value_offset = 0;

let mut iter = |mut f: u64, mut v: u64, len: usize| {
let mut iter = |mut f: u64, mut v: u64| {
if num_rows_left == 0 {
return false;
}
Expand Down Expand Up @@ -483,25 +483,23 @@ fn decode_masked_optional<B: AlignedBytes>(
unsafe {
target_ptr = target_ptr.add(num_written);
}
value_offset += len;
value_offset += num_chunk_values;
num_rows_left -= num_written;
num_values_left -= num_chunk_values;

true
};

for (f, v) in mask_iter.by_ref().zip(validity_iter.by_ref()) {
if !iter(f, v, 56) {
if !iter(f, v) {
break;
}
}

let (f, fl) = mask_iter.remainder();
let (v, vl) = validity_iter.remainder();

assert_eq!(fl, vl);

iter(f, v, fl);
iter(f, v);

unsafe { target.set_len(start_length + num_rows) };

Expand Down
83 changes: 83 additions & 0 deletions py-polars/tests/unit/io/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2912,3 +2912,86 @@ def test_nested_deprecated_int96_timestamps_21332() -> None:
pl.read_parquet(f),
df,
)


def test_final_masked_optional_iteration_21378() -> None:
    """Round-trip a nullable Float32 column through a prefiltered Parquet scan.

    A 448-row frame is written to an in-memory Parquet file and read back with
    ``parallel="prefiltered"`` while filtering on the boolean column ``f``.
    The result must equal the same filter applied eagerly to the original frame.

    NOTE(review): the name and issue number suggest this targets the last
    iteration of the masked optional decoder - confirm against issue 21378.
    """
    # Keep the 8-per-row layout of the literal below.
    # fmt: off
    null_flags = [
        1, 0, 0, 0, 0, 1, 1, 1,
        1, 0, 0, 1, 1, 1, 1, 0,
        0, 1, 1, 1, 0, 1, 0, 0,
        1, 1, 0, 0, 0, 1, 1, 1,
        0, 1, 0, 0, 1, 1, 1, 1,
        0, 1, 1, 1, 0, 1, 0, 1,
        0, 1, 1, 0, 1, 0, 1, 1,
        0, 0, 0, 0, 1, 0, 0, 0,
        0, 1, 1, 1, 0, 0, 1, 1,
        0, 0, 1, 1, 0, 0, 0, 1,
        1, 1, 0, 1, 1, 1, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0,
        1, 1, 1, 1, 1, 1, 1, 0,
        0, 0, 1, 0, 1, 1, 0, 0,
        0, 1, 1, 0, 0, 1, 0, 0,
        1, 1, 1, 1, 0, 0, 1, 0,
        0, 1, 1, 0, 0, 1, 1, 1,
        1, 1, 1, 0, 1, 1, 0, 1,
        0, 1, 0, 1, 0, 1, 0, 1,
        0, 0, 0, 1, 1, 0, 0, 0,
        1, 1, 0, 1, 0, 1, 0, 1,
        0, 1, 0, 0, 0, 0, 0, 1,
        0, 0, 1, 1, 0, 0, 1, 1,
        0, 1, 0, 0, 0, 1, 1, 1,
        1, 0, 1, 0, 1, 0, 1, 1,
        1, 0, 1, 0, 0, 1, 0, 1,
        0, 1, 1, 1, 0, 0, 0, 1,
        1, 1, 1, 1, 1, 1, 0, 0,
        1, 0, 0, 0, 0, 0, 0, 1,
        1, 1, 1, 0, 0, 0, 0, 0,
        1, 1, 1, 0, 0, 0, 1, 1,
        0, 0, 0, 0, 0, 1, 1, 0,
        0, 0, 1, 0, 0, 0, 0, 1,
        0, 0, 0, 0, 0, 1, 0, 0,
        1, 0, 1, 0, 0, 1, 0, 0,
        0, 1, 1, 1, 0, 0, 1, 1,
        1, 0, 1, 1, 0, 0, 0, 1,
        0, 0, 1, 1, 0, 1, 0, 1,
        0, 1, 1, 1, 0, 0, 0, 1,
        0, 0, 0, 1, 0, 1, 0, 1,
        0, 1, 0, 1, 0, 1, 1, 1,
        1, 0, 1, 1, 1, 1, 1, 0,
        1, 0, 1, 0, 0, 0, 1, 1,
        0, 0, 0, 1, 0, 0, 1, 0,
        0, 1, 0, 0, 1, 0, 1, 1,
        1, 0, 0, 1, 0, 1, 1, 0,
        0, 1, 0, 1, 1, 0, 1, 0,
        0, 0, 0, 1, 1, 1, 0, 0,
        0, 1, 0, 1, 1, 0, 1, 1,
        1, 1, 0, 1, 0, 1, 0, 1,
        1, 1, 0, 1, 0, 0, 1, 0,
        1, 1, 0, 1, 1, 0, 0, 1,
        0, 0, 0, 0, 0, 1, 0, 0,
        0, 1, 0, 0, 1, 1, 1, 1,
        1, 0, 1, 1, 1, 0, 1, 1,
        1, 1, 0, 0, 0, 0, 1, 1,
    ]

    # A flag of 1 marks a null slot; every non-null slot holds 0.0.
    nullable_floats = [None if flag == 1 else 0.0 for flag in null_flags]

    # Two selected runs of ten rows each: rows 164..173 and the final ten
    # rows 438..447, so the second run ends exactly at the last row.
    row_mask = [False] * 164 + [True] * 10 + [False] * 264 + [True] * 10

    df = pl.DataFrame(
        [
            pl.Series("x", nullable_floats, pl.Float32),
            pl.Series("f", row_mask, pl.Boolean()),
        ]
    )

    # Write to an in-memory buffer and rewind it for reading.
    buffer = io.BytesIO()
    df.write_parquet(buffer)
    buffer.seek(0)

    lazy_scan = pl.scan_parquet(buffer, parallel="prefiltered")
    result = lazy_scan.filter(pl.col.f).collect()
    assert_frame_equal(df.filter(pl.col.f), result)

0 comments on commit a379e68

Please sign in to comment.