Skip to content

Commit

Permalink
fix new-streaming more
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion committed Dec 6, 2024
1 parent 35f6440 commit 41ae898
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 21 deletions.
4 changes: 1 addition & 3 deletions crates/polars-io/src/parquet/read/read_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,9 +366,7 @@ fn rg_to_dfs_prefiltered(
}

let mask_setting = PrefilterMaskSetting::init_from_env();

// let projected_schema = schema.try_project_indices(projection).unwrap();
let projected_schema = schema.clone();
let projected_schema = schema.try_project_indices(projection).unwrap();

let dfs: Vec<Option<DataFrame>> = POOL.install(move || {
// Set partitioned fields to prevent quadratic behavior.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ impl RowGroupDecoder {
// for `hive::merge_sorted_to_schema_order`.
let mut opt_decode_err = None;

let mut decoded_live_cols_iter = self
let decoded_live_cols_iter = self
.predicate_arrow_field_indices
.iter()
.map(|&i| self.projected_arrow_schema.get_at_index(i).unwrap())
Expand All @@ -512,26 +512,21 @@ impl RowGroupDecoder {
},
}
});
let mut hive_cols_iter = shared_file_state.hive_series.iter().map(|s| {
let hive_cols_iter = shared_file_state.hive_series.iter().map(|s| {
debug_assert!(s.len() >= projection_height);
s.slice(0, projection_height)
});

hive::merge_sorted_to_schema_order(
&mut decoded_live_cols_iter,
&mut hive_cols_iter,
&self.reader_schema,
&mut live_columns,
);

live_columns.extend(decoded_live_cols_iter);
live_columns.extend(hive_cols_iter);
opt_decode_err.transpose()?;

if let Some(file_path_series) = &shared_file_state.file_path_series {
debug_assert!(file_path_series.len() >= projection_height);
live_columns.push(file_path_series.slice(0, projection_height));
}

let live_df = unsafe {
let mut live_df = unsafe {
DataFrame::new_no_checks(row_group_data.row_group_metadata.num_rows(), live_columns)
};

Expand All @@ -542,6 +537,12 @@ impl RowGroupDecoder {
.evaluate_io(&live_df)?;
let mask = mask.bool().unwrap();

unsafe {
live_df.get_columns_mut().truncate(
self.row_index.is_some() as usize + self.predicate_arrow_field_indices.len(),
)
}

let filtered =
unsafe { filter_cols(live_df.take_columns(), mask, self.min_values_per_thread) }
.await?;
Expand All @@ -552,10 +553,38 @@ impl RowGroupDecoder {
mask.num_trues()
};

let live_df_filtered = unsafe { DataFrame::new_no_checks(height, filtered) };
let mut live_df_filtered = unsafe { DataFrame::new_no_checks(height, filtered) };

let projection_height = height;

if self.non_predicate_arrow_field_indices.is_empty() {
// User or test may have explicitly requested prefiltering

hive::merge_sorted_to_schema_order(
unsafe {
&mut live_df_filtered
.get_columns_mut()
.drain(..)
.collect::<Vec<_>>()
.into_iter()
},
&mut shared_file_state
.hive_series
.iter()
.map(|s| s.slice(0, projection_height)),
&self.reader_schema,
unsafe { live_df_filtered.get_columns_mut() },
);

unsafe {
live_df_filtered.get_columns_mut().extend(
shared_file_state
.file_path_series
.as_ref()
.map(|c| c.slice(0, projection_height)),
)
}

return Ok(live_df_filtered);
}

Expand Down Expand Up @@ -621,13 +650,36 @@ impl RowGroupDecoder {
&mut live_columns
.into_iter()
.skip(self.row_index.is_some() as usize), // hive_columns
&self.reader_schema,
&self.projected_arrow_schema,
&mut merged,
);

opt_decode_err.transpose()?;

let df = unsafe { DataFrame::new_no_checks(expected_num_rows, merged) };
let mut out = Vec::with_capacity(
merged.len()
+ shared_file_state.hive_series.len()
+ shared_file_state.file_path_series.is_some() as usize,
);

hive::merge_sorted_to_schema_order(
&mut merged.into_iter(),
&mut shared_file_state
.hive_series
.iter()
.map(|s| s.slice(0, projection_height)),
&self.reader_schema,
&mut out,
);

out.extend(
shared_file_state
.file_path_series
.as_ref()
.map(|c| c.slice(0, projection_height)),
);

let df = unsafe { DataFrame::new_no_checks(expected_num_rows, out) };
Ok(df)
}
}
Expand Down
3 changes: 2 additions & 1 deletion crates/polars-stream/src/physical_plan/lower_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,8 @@ fn build_select_node_with_ctx(

if let Some(columns) = all_simple_columns {
let input_schema = ctx.phys_sm[input].output_schema.clone();
if input_schema.len() == columns.len()
if !cfg!(debug_assertions)
&& input_schema.len() == columns.len()
&& input_schema.iter_names().zip(&columns).all(|(l, r)| l == r)
{
// Input node already has the correct schema, just pass through.
Expand Down
19 changes: 15 additions & 4 deletions py-polars/tests/unit/io/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2594,7 +2594,17 @@ def test_utf8_verification_with_slice_20174() -> None:
)


def test_parquet_prefiltered_unordered_projection_20175() -> None:
@pytest.mark.parametrize("parallel", ["prefiltered", "row_groups"])
@pytest.mark.parametrize(
"projection",
[
{"a": pl.Int64(), "b": pl.Int64()},
{"b": pl.Int64(), "a": pl.Int64()},
],
)
def test_parquet_prefiltered_unordered_projection_20175(
parallel: str, projection: dict[str, pl.DataType]
) -> None:
df = pl.DataFrame(
[
pl.Series("a", [0], pl.Int64),
Expand All @@ -2607,9 +2617,10 @@ def test_parquet_prefiltered_unordered_projection_20175() -> None:

f.seek(0)
out = (
pl.scan_parquet(f, parallel="prefiltered")
pl.scan_parquet(f, parallel=parallel) # type: ignore[arg-type]
.filter(pl.col.a >= 0)
.select(["b", "a"])
.select(*projection.keys())
.collect()
)
assert out.schema == pl.Schema({"b": pl.Int64, "a": pl.Int64})

assert out.schema == projection

0 comments on commit 41ae898

Please sign in to comment.