From e143e7ee0b0d44b6539a2644a9ea7364666d65fa Mon Sep 17 00:00:00 2001
From: Gijs Burghoorn
Date: Tue, 11 Feb 2025 12:52:18 +0100
Subject: [PATCH] feat: Add row index to new streaming multiscan (#21169)

---
 crates/polars-error/src/lib.rs                  |   8 +
 crates/polars-plan/src/plans/schema.rs          |   5 +
 crates/polars-stream/src/nodes/io_sources/csv.rs |  69 +++++++--
 crates/polars-stream/src/nodes/io_sources/ipc.rs |  60 ++++----
 crates/polars-stream/src/nodes/io_sources/mod.rs |   4 +-
 .../src/nodes/io_sources/multi_scan.rs          | 137 +++++++++++++-----
 .../src/nodes/io_sources/parquet/mod.rs         |  58 +++++---
 .../src/physical_plan/lower_ir.rs               |  35 ++---
 crates/polars-stream/src/physical_plan/mod.rs   |   5 +-
 .../src/physical_plan/to_graph.rs               |   4 +
 py-polars/tests/unit/io/test_multiscan.py      |  72 +++++++++
 11 files changed, 335 insertions(+), 122 deletions(-)

diff --git a/crates/polars-error/src/lib.rs b/crates/polars-error/src/lib.rs
index 6bd3bedefe40..5b6726a94581 100644
--- a/crates/polars-error/src/lib.rs
+++ b/crates/polars-error/src/lib.rs
@@ -338,6 +338,14 @@ macro_rules! polars_err {
     (opq = $op:ident, $lhs:expr, $rhs:expr) => {
         $crate::polars_err!(op = stringify!($op), $lhs, $rhs)
     };
+    (bigidx, ctx = $ctx:expr, size = $size:expr) => {
+        polars_err!(ComputeError: "\
+{} produces {} rows which is more than maximum allowed pow(2, 32) rows; \
+consider compiling with bigidx feature (polars-u64-idx package on python)",
+            $ctx,
+            $size,
+        )
+    };
     (append) => {
         polars_err!(SchemaMismatch: "cannot append series, data types don't match")
     };
diff --git a/crates/polars-plan/src/plans/schema.rs b/crates/polars-plan/src/plans/schema.rs
index 0dc6b17474d4..660975f721a6 100644
--- a/crates/polars-plan/src/plans/schema.rs
+++ b/crates/polars-plan/src/plans/schema.rs
@@ -32,6 +32,11 @@ impl DslPlan {
 #[derive(Clone, Debug, Default)]
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 pub struct FileInfo {
+    /// Schema of the physical file.
+    ///
+    /// Notes:
+    /// - Does not include logical columns like `include_file_path` and row index.
+    /// - Always includes all hive columns.
     pub schema: SchemaRef,
     /// Stores the schema used for the reader, as the main schema can contain
     /// extra hive columns.
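The new `bigidx` arm above backs the overflow guard used at every per-file row-count site in this patch. A minimal standalone sketch of that guard, with `u32` standing in for `IdxSize` and a plain `String` error in place of the `polars_err!` machinery (both are stand-ins for illustration, not the actual Polars types):

```rust
// Sketch of the overflow guard used throughout this patch: converting a usize
// row count into a 32-bit index type must fail loudly instead of truncating.
fn checked_row_count(num_rows: usize, ctx: &str) -> Result<u32, String> {
    u32::try_from(num_rows).map_err(|_| {
        format!(
            "{ctx} produces {num_rows} rows which is more than the maximum \
             allowed 2^32 rows; consider compiling with the bigidx feature"
        )
    })
}

fn main() {
    assert_eq!(checked_row_count(3, "csv file"), Ok(3));
    // Assumes a 64-bit target so the oversized count is representable.
    assert!(checked_row_count(1usize << 33, "csv file").is_err());
}
```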
diff --git a/crates/polars-stream/src/nodes/io_sources/csv.rs b/crates/polars-stream/src/nodes/io_sources/csv.rs
index 3278b4e7186b..fb52ffb95ab6 100644
--- a/crates/polars-stream/src/nodes/io_sources/csv.rs
+++ b/crates/polars-stream/src/nodes/io_sources/csv.rs
@@ -1,3 +1,4 @@
+use std::sync::atomic::Ordering;
 use std::sync::Arc;
 
 use futures::stream::FuturesUnordered;
@@ -9,7 +10,7 @@ use polars_core::schema::{SchemaExt, SchemaRef};
 use polars_core::utils::arrow::bitmap::Bitmap;
 #[cfg(feature = "dtype-categorical")]
 use polars_core::StringCacheHolder;
-use polars_error::{polars_bail, PolarsResult};
+use polars_error::{polars_bail, polars_err, PolarsResult};
 use polars_io::cloud::CloudOptions;
 use polars_io::prelude::_csv_read_internal::{
     cast_columns, find_starting_point, prepare_csv_schema, read_chunk, CountLines,
@@ -22,6 +23,7 @@ use polars_io::utils::slice::SplitSlicePosition;
 use polars_io::RowIndex;
 use polars_plan::plans::{csv_file_info, FileInfo, ScanSource, ScanSources};
 use polars_plan::prelude::FileScanOptions;
+use polars_utils::index::AtomicIdxSize;
 use polars_utils::mmap::MemSlice;
 use polars_utils::pl_str::PlSmallStr;
 use polars_utils::IdxSize;
@@ -96,7 +98,7 @@ impl SourceNode for CsvSourceNode {
         mut output_recv: Receiver<SourceOutput>,
         _state: &ExecutionState,
         join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
-        _unrestricted_row_count: Option<PlSmallStr>,
+        unrestricted_row_count: Option<Arc<AtomicIdxSize>>,
     ) {
         let (mut send_to, recv_from) = (0..num_pipelines)
             .map(|_| connector::())
             .collect::<(Vec<_>, Vec<_>)>();
 
         let source_token = SourceToken::new();
         let (line_batch_receivers, chunk_reader, line_batch_source_task_handle) =
-            self.init_line_batch_source(num_pipelines);
+            self.init_line_batch_source(num_pipelines, unrestricted_row_count);
 
         join_handles.extend(line_batch_receivers.into_iter().zip(recv_from).map(
             |(mut line_batch_rx, mut recv_from)| {
@@ -202,7 +204,11 @@ impl SourceNode for CsvSourceNode {
 }
 
 impl CsvSourceNode {
-    fn init_line_batch_source(&mut self, num_pipelines: usize) -> AsyncTaskData {
+    fn init_line_batch_source(
+        &mut self,
+        num_pipelines: usize,
+        unrestricted_row_count: Option<Arc<AtomicIdxSize>>,
+    ) -> AsyncTaskData {
         let verbose = self.verbose;
 
         let (mut line_batch_senders, line_batch_receivers): (Vec<_>, Vec<_>) =
@@ -227,6 +233,9 @@ impl CsvSourceNode {
         let global_slice = self.file_options.slice;
         let include_file_paths = self.file_options.include_file_paths.is_some();
 
+        // We don't deal with this yet for unrestricted_row_count.
+        assert!(unrestricted_row_count.is_none() || global_slice.is_none());
+
         if verbose {
             eprintln!(
                 "[CsvSource]: slice: {:?}, row_index: {:?}",
@@ -423,6 +432,13 @@ impl CsvSourceNode {
                     }
                 }
 
+                if let Some(unrestricted_row_count) = unrestricted_row_count.as_ref() {
+                    let num_rows = *current_row_offset_ref;
+                    let num_rows = IdxSize::try_from(num_rows)
+                        .map_err(|_| polars_err!(bigidx, ctx = "csv file", size = num_rows))?;
+                    unrestricted_row_count.store(num_rows, Ordering::Relaxed);
+                }
+
                 Ok(())
             }),
         );
@@ -604,16 +620,20 @@ impl MultiScanable for CsvSourceNode {
 
     const DOES_PRED_PD: bool = false;
     const DOES_SLICE_PD: bool = true;
-    const DOES_ROW_INDEX: bool = false;
 
     async fn new(
        source: ScanSource,
        options: &Self::ReadOptions,
        cloud_options: Option<&CloudOptions>,
+       row_index: Option<PlSmallStr>,
    ) -> PolarsResult<Self> {
        let sources = source.into_sources();
 
-       let file_options = FileScanOptions::default();
+       let file_options = FileScanOptions {
+           row_index: row_index.map(|name| RowIndex { name, offset: 0 }),
+           ..Default::default()
+       };
+
        let mut csv_options = options.clone();
        let file_info = csv_file_info(&sources, &file_options, &mut csv_options, cloud_options)?;
@@ -631,15 +651,38 @@ impl MultiScanable for CsvSourceNode {
         _ = row_restriction;
         todo!()
     }
-    fn with_row_index(&mut self, row_index: Option<PlSmallStr>) {
-        _ = row_index;
-        todo!()
-    }
 
-    async fn row_count(&mut self) -> PolarsResult<IdxSize> {
-        todo!()
+    async fn unrestricted_row_count(&mut self) -> PolarsResult<IdxSize> {
+        let parse_options = self.options.get_parse_options();
+        let source = self
+            .scan_sources
+            .at(0)
+            .to_memslice_async_assume_latest(true)?;
+
+        let mem_slice = {
+            let mut out = vec![];
+            maybe_decompress_bytes(&source, &mut out)?;
+
+            if out.is_empty() {
+                source
+            } else {
+                MemSlice::from_vec(out)
+            }
+        };
+
+        let num_rows = polars_io::csv::read::count_rows_from_slice(
+            &mem_slice[..],
+            parse_options.separator,
+            parse_options.quote_char,
+            parse_options.comment_prefix.as_ref(),
+            parse_options.eol_char,
+            self.options.has_header,
+        )?;
+        let num_rows = IdxSize::try_from(num_rows)
+            .map_err(|_| polars_err!(bigidx, ctx = "csv file", size = num_rows))?;
+        Ok(num_rows)
     }
-    async fn schema(&mut self) -> PolarsResult<SchemaRef> {
+
+    async fn physical_schema(&mut self) -> PolarsResult<SchemaRef> {
         Ok(self.file_info.schema.clone())
     }
 }
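CSV has no metadata to lean on, so `unrestricted_row_count` above decompresses the first source and counts records with `count_rows_from_slice`. A rough sketch of what such a count has to respect, quote-aware newline handling, in plain Rust under simplified assumptions (`"` quoting only, no comment prefix, `\n` record terminator):

```rust
// Sketch: count CSV records in a byte slice, ignoring newlines inside quoted
// fields. A simplified stand-in for `count_rows_from_slice`, not its real
// implementation. Doubled quotes ("") work because the flag toggles twice.
fn count_csv_rows(bytes: &[u8], has_header: bool) -> usize {
    let mut in_quotes = false;
    let mut rows = 0;
    let mut saw_data = false;
    for &b in bytes {
        match b {
            b'"' => in_quotes = !in_quotes,
            b'\n' if !in_quotes => {
                rows += 1;
                saw_data = false;
            },
            b'\r' => {},
            _ => saw_data = true,
        }
    }
    // Account for a final record without a trailing newline.
    rows += usize::from(saw_data);
    rows - usize::from(has_header && rows > 0)
}

fn main() {
    let data = b"a,b\n1,\"x\ny\"\n2,z\n";
    assert_eq!(count_csv_rows(data, true), 2);
}
```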
diff --git a/crates/polars-stream/src/nodes/io_sources/ipc.rs b/crates/polars-stream/src/nodes/io_sources/ipc.rs
index f56ea952f115..57b04854b22b 100644
--- a/crates/polars-stream/src/nodes/io_sources/ipc.rs
+++ b/crates/polars-stream/src/nodes/io_sources/ipc.rs
@@ -1,6 +1,7 @@
 use std::cmp::Reverse;
 use std::io::Cursor;
 use std::ops::Range;
+use std::sync::atomic::Ordering;
 use std::sync::Arc;
 
 use polars_core::config;
@@ -15,7 +16,7 @@ use polars_core::utils::arrow::io::ipc::read::{
     ProjectionInfo,
 };
 use polars_core::utils::slice_offsets;
-use polars_error::{ErrString, PolarsError, PolarsResult};
+use polars_error::{polars_err, ErrString, PolarsError, PolarsResult};
 use polars_expr::state::ExecutionState;
 use polars_io::cloud::CloudOptions;
 use polars_io::ipc::IpcScanOptions;
@@ -23,6 +24,7 @@ use polars_io::utils::columns_to_projection;
 use polars_io::RowIndex;
 use polars_plan::plans::{FileInfo, ScanSource, ScanSources};
 use polars_plan::prelude::FileScanOptions;
+use polars_utils::index::AtomicIdxSize;
 use polars_utils::mmap::MemSlice;
 use polars_utils::pl_str::PlSmallStr;
 use polars_utils::priority::Priority;
@@ -51,6 +53,7 @@ pub struct IpcSourceNode {
     row_index: Option<RowIndex>,
     slice: Range<usize>,
 
+    file_info: FileInfo,
     projection_info: Option<ProjectionInfo>,
 
     rechunk: bool,
@@ -68,7 +71,7 @@ impl IpcSourceNode {
     #[allow(clippy::too_many_arguments)]
     pub fn new(
         sources: ScanSources,
-        _file_info: FileInfo,
+        file_info: FileInfo,
         options: IpcScanOptions,
         _cloud_options: Option<CloudOptions>,
         file_options: FileScanOptions,
@@ -136,7 +139,9 @@ impl IpcSourceNode {
             slice,
             row_index,
+            projection_info,
+            file_info,
             rechunk,
             include_file_paths,
@@ -186,7 +191,7 @@ impl SourceNode for IpcSourceNode {
         mut output_recv: Receiver<SourceOutput>,
         _state: &ExecutionState,
         join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
-        unrestricted_row_count: Option<PlSmallStr>,
+        unrestricted_row_count: Option<Arc<AtomicIdxSize>>,
     ) {
         // Split size for morsels.
         let max_morsel_size = get_max_morsel_size();
@@ -197,6 +202,7 @@ impl SourceNode for IpcSourceNode {
             row_index,
             slice,
             projection_info,
+            file_info: _,
             rechunk,
             include_file_paths,
         } = self;
@@ -238,18 +244,6 @@ impl SourceNode for IpcSourceNode {
                     let mut morsel = Morsel::new(df, seq, source_token.clone());
                     morsel.set_consume_token(wait_group.token());
 
-                    if let Some(rc) = unrestricted_row_count.as_ref() {
-                        morsel = morsel.map(|mut df| {
-                            df.with_column(Column::new_scalar(
-                                rc.clone(),
-                                Scalar::from(df.height() as IdxSize),
-                                df.height(),
-                            ))
-                            .unwrap();
-                            df
-                        });
-                    }
-
                     if sender.send(morsel).await.is_err() {
                         return Ok(());
                     }
@@ -298,6 +292,8 @@ impl SourceNode for IpcSourceNode {
                         block_range,
                     } = m;
 
+                    // If we don't project any columns we cannot read properly from the file,
+                    // so we just create an empty frame with the proper height.
                     let mut df = if pl_schema.is_empty() {
                         DataFrame::empty_with_height(slice.len())
                     } else {
@@ -320,7 +316,6 @@ impl SourceNode for IpcSourceNode {
                         df.try_extend(reader.by_ref().take(block_range.len()))?;
 
                         (data_scratch, message_scratch) = reader.take_scratches();
-
                         df = df.slice(slice.start as i64, slice.len());
 
                         if rechunk {
@@ -372,6 +367,17 @@ impl SourceNode for IpcSourceNode {
     //
     // Walks all the sources and supplies block ranges to the decoder tasks.
     join_handles.push(spawn(TaskPriority::Low, async move {
+        // Calculate the unrestricted row count if needed.
+        if let Some(rc) = unrestricted_row_count {
+            let num_rows = get_row_count_from_blocks(
+                &mut std::io::Cursor::new(source.memslice.as_ref()),
+                &source.metadata.blocks,
+            )?;
+            let num_rows = IdxSize::try_from(num_rows)
+                .map_err(|_| polars_err!(bigidx, ctx = "ipc file", size = num_rows))?;
+            rc.store(num_rows, Ordering::Relaxed);
+        }
+
         let mut morsel_seq: u64 = 0;
         let mut row_idx_offset: IdxSize = row_index.as_ref().map_or(0, |ri| ri.offset);
         let mut slice: Range<usize> = slice;
@@ -508,12 +514,12 @@ impl MultiScanable for IpcSourceNode {
 
     const DOES_PRED_PD: bool = false;
     const DOES_SLICE_PD: bool = true;
-    const DOES_ROW_INDEX: bool = true;
 
     async fn new(
         source: ScanSource,
         options: &Self::ReadOptions,
         cloud_options: Option<&CloudOptions>,
+        row_index: Option<PlSmallStr>,
     ) -> PolarsResult<Self> {
         let source = source.into_sources();
         let options = options.clone();
@@ -525,10 +531,15 @@ impl MultiScanable for IpcSourceNode {
         let arrow_schema = metadata.schema.clone();
         let schema = Schema::from_arrow_schema(arrow_schema.as_ref());
+        let schema = Arc::new(schema);
+
+        let mut file_options = FileScanOptions::default();
+        if let Some(name) = row_index {
+            file_options.row_index = Some(RowIndex { name, offset: 0 });
+        }
 
-        let file_options = FileScanOptions::default();
         let file_info = FileInfo::new(
-            Arc::new(schema),
+            schema,
             Some(rayon::iter::Either::Left(arrow_schema)),
             (None, usize::MAX),
         );
@@ -558,20 +569,15 @@ impl MultiScanable for IpcSourceNode {
             }
         }
     }
-    fn with_row_index(&mut self, row_index: Option<PlSmallStr>) {
-        self.row_index = row_index.map(|name| RowIndex { name, offset: 0 });
-    }
 
-    async fn row_count(&mut self) -> PolarsResult<IdxSize> {
+    async fn unrestricted_row_count(&mut self) -> PolarsResult<IdxSize> {
         get_row_count_from_blocks(
             &mut std::io::Cursor::new(self.source.memslice.as_ref()),
             &self.source.metadata.blocks,
         )
         .map(|v| v as IdxSize)
     }
-    async fn schema(&mut self) -> PolarsResult<SchemaRef> {
-        Ok(Arc::new(Schema::from_arrow_schema(
-            &self.source.metadata.schema,
-        )))
+
+    async fn physical_schema(&mut self) -> PolarsResult<SchemaRef> {
+        Ok(self.file_info.schema.clone())
     }
 }
diff --git a/crates/polars-stream/src/nodes/io_sources/mod.rs b/crates/polars-stream/src/nodes/io_sources/mod.rs
index 77ee209635a7..f29226ae7fe3 100644
--- a/crates/polars-stream/src/nodes/io_sources/mod.rs
+++ b/crates/polars-stream/src/nodes/io_sources/mod.rs
@@ -3,7 +3,7 @@ use std::sync::Arc;
 
 use polars_error::PolarsResult;
 use polars_expr::state::ExecutionState;
-use polars_utils::pl_str::PlSmallStr;
+use polars_utils::index::AtomicIdxSize;
 
 use super::{ComputeNode, JoinHandle, Morsel, PortState, RecvPort, SendPort, TaskPriority};
 use crate::async_primitives::connector::{connector, Receiver, Sender};
@@ -257,6 +257,6 @@ pub trait SourceNode: Sized + Send + Sync {
         output_recv: Receiver<SourceOutput>,
         state: &ExecutionState,
         join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
-        unrestricted_row_count: Option<PlSmallStr>,
+        unrestricted_row_count: Option<Arc<AtomicIdxSize>>,
     );
 }
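The `Option<Arc<AtomicIdxSize>>` parameter threads a slot through `spawn_source` that the source task fills in as a side effect while it streams data. A toy sketch of that handshake, with `std::thread` and `AtomicU32` standing in for the async task machinery and `AtomicIdxSize`:

```rust
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::thread;

// Sketch: the source task publishes the file's total row count through a
// shared atomic; the orchestrator reads it once the phase is done. Relaxed
// ordering suffices because joining the task already synchronizes.
fn spawn_source(unrestricted_row_count: Option<Arc<AtomicU32>>) -> thread::JoinHandle<()> {
    thread::spawn(move || {
        let num_rows = 1234; // stand-in for counting rows while reading
        if let Some(rc) = unrestricted_row_count {
            rc.store(num_rows, Ordering::Relaxed);
        }
    })
}

fn main() {
    let rc = Arc::new(AtomicU32::new(0));
    spawn_source(Some(rc.clone())).join().unwrap();
    assert_eq!(rc.load(Ordering::Relaxed), 1234);
}
```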
diff --git a/crates/polars-stream/src/nodes/io_sources/multi_scan.rs b/crates/polars-stream/src/nodes/io_sources/multi_scan.rs
index 1e490aab31eb..5c60ee7aa5f5 100644
--- a/crates/polars-stream/src/nodes/io_sources/multi_scan.rs
+++ b/crates/polars-stream/src/nodes/io_sources/multi_scan.rs
@@ -2,6 +2,7 @@ use std::cmp::Reverse;
 use std::future::Future;
 use std::marker::PhantomData;
 use std::ops::Range;
+use std::sync::atomic::Ordering;
 use std::sync::Arc;
 
 use polars_core::frame::column::ScalarColumn;
@@ -13,9 +14,11 @@ use polars_core::utils::arrow::bitmap::{Bitmap, MutableBitmap};
 use polars_error::{polars_bail, PolarsResult};
 use polars_expr::state::ExecutionState;
 use polars_io::cloud::CloudOptions;
+use polars_io::RowIndex;
 use polars_mem_engine::ScanPredicate;
 use polars_plan::plans::hive::HivePartitions;
 use polars_plan::plans::{ScanSource, ScanSourceRef, ScanSources};
+use polars_utils::index::AtomicIdxSize;
 use polars_utils::pl_str::PlSmallStr;
 use polars_utils::priority::Priority;
 use polars_utils::{format_pl_smallstr, IdxSize};
@@ -58,6 +61,7 @@ pub struct MultiScanNode {
     file_schema: SchemaRef,
 
     projection: Option<Bitmap>,
+    row_index: Option<RowIndex>,
 
     read_options: Arc<T::ReadOptions>,
     cloud_options: Arc<Option<CloudOptions>>,
@@ -76,6 +80,7 @@ impl MultiScanNode {
         file_schema: SchemaRef,
 
         projection: Option<Bitmap>,
+        row_index: Option<RowIndex>,
 
         read_options: T::ReadOptions,
         cloud_options: Option<CloudOptions>,
@@ -90,6 +95,7 @@ impl MultiScanNode {
             file_schema,
 
             projection,
+            row_index,
 
             read_options: Arc::new(read_options),
             cloud_options: Arc::new(cloud_options),
@@ -99,6 +105,7 @@ impl MultiScanNode {
     }
 }
 
+#[allow(clippy::too_many_arguments)]
 fn process_dataframe(
     mut df: DataFrame,
     source_name: &PlSmallStr,
@@ -108,13 +115,27 @@ fn process_dataframe(
     file_schema: &Schema,
     projection: Option<&Bitmap>,
+    row_index: Option<&RowIndex>,
 ) -> PolarsResult<DataFrame> {
+    _ = df.schema();
+
+    if let Some(ri) = row_index {
+        let ri_column = df
+            .get_column_index(ri.name.as_str())
+            .expect("should have row index column here");
+
+        let columns = unsafe { df.get_columns_mut() };
+        columns[ri_column] = (std::mem::take(&mut columns[ri_column]).take_materialized_series()
+            + Scalar::from(ri.offset).into_series(PlSmallStr::EMPTY))?
+        .into_column();
+    }
+
     if let Some(hive_part) = hive_part {
         let height = df.height();
 
         if cfg!(debug_assertions) {
-            // We should have projected the hive column out when we read the file.
             let schema = df.schema();
+            // We should have projected the hive column out when we read the file.
             for column in hive_part.get_statistics().column_stats().iter() {
                 assert!(!schema.contains(column.field_name()));
             }
@@ -237,6 +258,7 @@ fn resolve_source_projection(
     });
 
     let mut source_projection = MutableBitmap::from_len_zeroed(source_schema.len());
+
     let mut j = 0;
     for (i, source_col_name) in source_schema.iter_names().enumerate() {
         while let Some((file_col_name, _)) = file_schema.get_at_index(j) {
@@ -271,23 +293,29 @@ pub trait MultiScanable: SourceNode + Sized + Send + Sync {
 
     const DOES_PRED_PD: bool;
     const DOES_SLICE_PD: bool;
-    const DOES_ROW_INDEX: bool;
 
     fn new(
         source: ScanSource,
         options: &Self::ReadOptions,
         cloud_options: Option<&CloudOptions>,
+        row_index: Option<PlSmallStr>,
     ) -> impl Future<Output = PolarsResult<Self>> + Send;
 
+    /// Provide a selection of physical columns to be loaded.
+    ///
+    /// The provided bitmap should have the same length as the schema given by
+    /// [`MultiScanable::physical_schema`].
     fn with_projection(&mut self, projection: Option<&Bitmap>);
-    #[allow(unused)]
+    #[expect(unused)]
     fn with_row_restriction(&mut self, row_restriction: Option<RowRestriction>);
-    #[allow(unused)]
-    fn with_row_index(&mut self, row_index: Option<PlSmallStr>);
-    #[allow(unused)]
-    fn row_count(&mut self) -> impl Future<Output = PolarsResult<IdxSize>> + Send;
 
-    fn schema(&mut self) -> impl Future<Output = PolarsResult<SchemaRef>> + Send;
+    /// Get the number of physical rows in this source.
+    fn unrestricted_row_count(&mut self) -> impl Future<Output = PolarsResult<IdxSize>> + Send;
+
+    /// Schema inferred from the source.
+    ///
+    /// This should **NOT** include any logical columns (e.g. file path, row index, hive columns).
+    fn physical_schema(&mut self) -> impl Future<Output = PolarsResult<SchemaRef>> + Send;
 }
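Each reader now starts its row index at 0, and the multiscan node shifts the column afterwards by the number of rows in all preceding files (see `process_dataframe` above). A sketch of that shift with a plain `Vec<u32>` standing in for the materialized `IdxSize` column:

```rust
// Sketch: per-file row indices start at 0; the orchestrator adds the running
// offset before handing the morsel downstream.
fn apply_row_index_offset(row_index: &mut [u32], offset: u32) {
    for idx in row_index.iter_mut() {
        *idx += offset;
    }
}

fn main() {
    // Second file of a scan; the first file had 3 rows and the scan started at 0.
    let mut ri: Vec<u32> = vec![0, 1, 2, 3];
    apply_row_index_offset(&mut ri, 3);
    assert_eq!(ri, vec![3, 4, 5, 6]);
}
```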
 enum SourceInput {
@@ -306,6 +334,18 @@ fn num_concurrent_scans(num_pipelines: usize) -> usize {
     num_pipelines.min(max_num_concurrent_scans)
 }
 
+enum SourcePhaseContent {
+    /// 1+ columns
+    NonEmpty(SourceInput),
+    /// 0 columns
+    Empty(IdxSize),
+}
+struct SourcePhase {
+    content: SourcePhaseContent,
+    unrestricted_row_count: Option<Arc<AtomicIdxSize>>,
+    missing_columns: Option<Bitmap>,
+}
+
 impl SourceNode for MultiScanNode {
     fn name(&self) -> &str {
         &self.name
@@ -321,7 +361,7 @@ impl SourceNode for MultiScanNode {
         mut send_port_recv: Receiver<SourceOutput>,
         _state: &ExecutionState,
         join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
-        unrestricted_row_count: Option<PlSmallStr>,
+        unrestricted_row_count: Option<Arc<AtomicIdxSize>>,
     ) {
         assert!(unrestricted_row_count.is_none());
@@ -331,21 +371,24 @@ impl SourceNode for MultiScanNode {
         let cloud_options = &self.cloud_options;
         let file_schema = &self.file_schema;
         let projection = &self.projection;
+        let row_index_name = self.row_index.as_ref().map(|ri| &ri.name);
         let allow_missing_columns = self.allow_missing_columns;
         let hive_schema = self
             .hive_parts
             .as_ref()
             .and_then(|p| Some(p.first()?.get_statistics().schema().clone()))
             .unwrap_or_else(|| Arc::new(Schema::default()));
-        let physical_columns: Bitmap = file_schema
+        let source_provided_columns: Bitmap = file_schema
             .iter_names()
             .map(|n| {
-                !hive_schema.contains(n) && self.include_file_paths.as_ref().is_none_or(|c| c != n)
+                !hive_schema.contains(n)
+                    && self.include_file_paths.as_ref().is_none_or(|c| c != n)
+                    && self.row_index.as_ref().is_none_or(|c| c.name != n)
             })
             .collect();
 
         let (si_send, mut si_recv) = (0..num_concurrent_scans)
-            .map(|_| connector::<(Result<SourceInput, IdxSize>, Option<Bitmap>)>())
+            .map(|_| connector::<SourcePhase>())
             .collect::<(Vec<_>, Vec<_>)>();
 
         join_handles.extend(si_send.into_iter().enumerate().map(|(mut i, mut si_send)| {
@@ -354,7 +397,8 @@ impl SourceNode for MultiScanNode {
             let read_options = read_options.clone();
             let cloud_options = cloud_options.clone();
             let file_schema = file_schema.clone();
             let projection = projection.clone();
-            let physical_columns = physical_columns.clone();
+            let row_index_name = row_index_name.cloned();
+            let physical_columns = source_provided_columns.clone();
 
             spawn(TaskPriority::High, async move {
                 let state = ExecutionState::new();
@@ -368,10 +412,11 @@ impl SourceNode for MultiScanNode {
                         source,
                         read_options.as_ref(),
                         cloud_options.as_ref().as_ref(),
+                        row_index_name.clone(),
                     )
                     .await?;
 
-                    let source_schema = source.schema().await?;
+                    let source_schema = source.physical_schema().await?;
                     let (source_projection, missing_columns) = resolve_source_projection(
                         file_schema.as_ref(),
                         source_schema.as_ref(),
@@ -384,29 +429,37 @@ impl SourceNode for MultiScanNode {
 
                     // If we are not interested in any column, just count the rows and send that
                     // back.
-                    if source_projection.set_bits() == 0 {
-                        let row_count = source.row_count().await?;
+                    if row_index_name.is_none() && source_projection.set_bits() == 0 {
+                        let row_count = source.unrestricted_row_count().await?;
+                        let phase = SourcePhase {
+                            content: SourcePhaseContent::Empty(row_count),
+                            missing_columns: missing_columns.clone(),
+                            unrestricted_row_count: row_index_name
+                                .is_some()
+                                .then(|| Arc::new(AtomicIdxSize::new(row_count))),
+                        };
 
                         // Wait for the orchestrator task to actually be interested in the output
                         // of this file.
-                        if si_send
-                            .send((Err(row_count), missing_columns.clone()))
-                            .await
-                            .is_err()
-                        {
+                        if si_send.send(phase).await.is_err() {
                             break;
                         };
                         i += num_concurrent_scans;
                         continue;
                     }
 
+                    let unrestricted_row_count = row_index_name
+                        .is_some()
+                        .then(|| Arc::new(AtomicIdxSize::new(0)));
+
                     source.with_projection(Some(&source_projection));
+
                     source.spawn_source(
                         num_pipelines,
                         output_recv,
                         &state,
                         &mut join_handles,
-                        None,
+                        unrestricted_row_count.clone(),
                     );
 
                     // Loop until a phase result indicated that the source is empty.
@@ -421,13 +474,15 @@ impl SourceNode for MultiScanNode {
                             (SourceOutputPort::Serial(tx), SourceInput::Serial(rx))
                         };
 
+                        let phase = SourcePhase {
+                            content: SourcePhaseContent::NonEmpty(rx),
+                            missing_columns: missing_columns.clone(),
+                            unrestricted_row_count: unrestricted_row_count.clone(),
+                        };
+
                         // Wait for the orchestrator task to actually be interested in the output
                         // of this file.
-                        if si_send
-                            .send((Ok(rx), missing_columns.clone()))
-                            .await
-                            .is_err()
-                        {
+                        if si_send.send(phase).await.is_err() {
                             break;
                         };
@@ -481,6 +536,7 @@ impl SourceNode for MultiScanNode {
         let include_file_paths = self.include_file_paths.clone();
         let file_schema = self.file_schema.clone();
         let projection = self.projection.clone();
+        let mut row_index = self.row_index.clone();
         let sources = sources.clone();
         join_handles.push(spawn(TaskPriority::High, async move {
             let mut seq = MorselSeq::default();
@@ -498,24 +554,25 @@ impl SourceNode for MultiScanNode {
                 let source_name = source_name(sources.at(current_scan), current_scan);
                 let hive_part = hive_parts.as_deref().map(|parts| &parts[current_scan]);
                 let si_recv = &mut si_recv[current_scan % num_concurrent_scans];
-                let Ok((rx, missing_columns)) = si_recv.recv().await else {
+                let Ok(phase) = si_recv.recv().await else {
                     return Ok(());
                 };
 
-                match rx {
+                match phase.content {
                     // In certain cases, we don't actually need to read physical data from the
                     // file so we get back a row count.
-                    Err(row_count) => {
+                    SourcePhaseContent::Empty(row_count) => {
                         let df = DataFrame::new_with_height(row_count as usize, Vec::new()).unwrap();
                         let df = process_dataframe(
                             df,
                             &source_name,
                             hive_part,
-                            missing_columns.as_ref(),
+                            phase.missing_columns.as_ref(),
                             include_file_paths.as_ref(),
                             file_schema.as_ref(),
                             projection.as_ref(),
+                            row_index.as_ref(),
                         );
                         let df = match df {
                             Ok(df) => df,
@@ -538,8 +595,7 @@ impl SourceNode for MultiScanNode {
                             continue 'phase_loop;
                         }
                     },
-
-                    Ok(rx) => match rx {
+                    SourcePhaseContent::NonEmpty(rx) => match rx {
                         SourceInput::Serial(mut rx) => {
                             while let Ok(rg) = rx.recv().await {
                                 let original_source_token = rg.source_token().clone();
@@ -549,10 +605,11 @@ impl SourceNode for MultiScanNode {
                                     df,
                                     &source_name,
                                     hive_part,
-                                    missing_columns.as_ref(),
+                                    phase.missing_columns.as_ref(),
                                     include_file_paths.as_ref(),
                                     file_schema.as_ref(),
                                     projection.as_ref(),
+                                    row_index.as_ref(),
                                 );
                                 let df = match df {
                                     Ok(df) => df,
@@ -600,10 +657,11 @@ impl SourceNode for MultiScanNode {
                                     df,
                                     &source_name,
                                     hive_part,
-                                    missing_columns.as_ref(),
+                                    phase.missing_columns.as_ref(),
                                     include_file_paths.as_ref(),
                                     file_schema.as_ref(),
                                     projection.as_ref(),
+                                    row_index.as_ref(),
                                 );
                                 let df = match df {
                                     Ok(df) => df,
@@ -629,8 +687,15 @@ impl SourceNode for MultiScanNode {
                             }
                         },
                     },
-                };
+                }
 
+                if let Some(ri) = row_index.as_mut() {
+                    let source_num_rows = phase
+                        .unrestricted_row_count
+                        .unwrap()
+                        .load(Ordering::Relaxed);
+                    ri.offset += source_num_rows;
+                }
                 current_scan += 1;
             }
             break;
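After a file is fully drained, the orchestrator above reads the count the source published and advances the running row-index offset. A sketch of that bookkeeping, using three hypothetical files with made-up row counts:

```rust
// Sketch: file k's rows get indices [offset_k, offset_k + len_k), and
// offset_{k+1} = offset_k + len_k. The per-file lengths come from the
// published unrestricted row counts shown earlier.
fn main() {
    let file_row_counts = [3u32, 1, 2]; // e.g. a.csv, b.csv, c.csv
    let mut offset = 42; // user-provided row_index_offset
    for (file, len) in file_row_counts.iter().enumerate() {
        println!("file {file}: row index starts at {offset}");
        offset += len;
    }
    assert_eq!(offset, 42 + 6);
}
```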
diff --git a/crates/polars-stream/src/nodes/io_sources/parquet/mod.rs b/crates/polars-stream/src/nodes/io_sources/parquet/mod.rs
index 3a109366018e..dc058fefa9b4 100644
--- a/crates/polars-stream/src/nodes/io_sources/parquet/mod.rs
+++ b/crates/polars-stream/src/nodes/io_sources/parquet/mod.rs
@@ -1,3 +1,4 @@
+use std::sync::atomic::Ordering;
 use std::sync::Arc;
 
 use mem_prefetch_funcs::get_memory_prefetch_func;
@@ -5,11 +6,12 @@ use polars_core::config;
 use polars_core::prelude::ArrowSchema;
 use polars_core::schema::{Schema, SchemaExt, SchemaRef};
 use polars_core::utils::arrow::bitmap::Bitmap;
-use polars_error::PolarsResult;
+use polars_error::{polars_err, PolarsResult};
 use polars_io::cloud::CloudOptions;
 use polars_io::predicates::ScanIOPredicate;
 use polars_io::prelude::{FileMetadata, ParquetOptions};
 use polars_io::utils::byte_source::DynByteSourceBuilder;
+use polars_io::RowIndex;
 use polars_parquet::read::read_metadata;
 use polars_parquet::read::schema::infer_schema_with_options;
 use polars_plan::plans::hive::HivePartitions;
@@ -151,7 +153,7 @@ impl SourceNode for ParquetSourceNode {
         mut output_recv: Receiver<SourceOutput>,
         _state: &ExecutionState,
         join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
-        _unrestricted_row_count: Option<PlSmallStr>,
+        unrestricted_row_count: Option<Arc<AtomicIdxSize>>,
     ) {
         let (mut send_to, recv_from) = (0..num_pipelines)
             .map(|_| connector())
             .collect::<(Vec<_>, Vec<_>)>();
@@ -182,6 +184,7 @@ impl SourceNode for ParquetSourceNode {
             eprintln!("[ParquetSource]: {:?}", &self.config);
         }
 
+        let num_rows = self.first_metadata.as_ref().unwrap().num_rows;
         self.schema = Some(self.file_info.reader_schema.take().unwrap().unwrap_left());
         self.init_projected_arrow_schema();
 
@@ -190,6 +193,12 @@ impl SourceNode for ParquetSourceNode {
         let morsel_stream_starter = self.morsel_stream_starter.take().unwrap();
 
         join_handles.push(spawn(TaskPriority::Low, async move {
+            if let Some(rc) = unrestricted_row_count {
+                let num_rows = IdxSize::try_from(num_rows)
+                    .map_err(|_| polars_err!(bigidx, ctx = "parquet file", size = num_rows))?;
+                rc.store(num_rows, Ordering::Relaxed);
+            }
+
             morsel_stream_starter.send(()).unwrap();
 
             // Every phase we are given a new send port.
@@ -261,25 +270,31 @@ impl MultiScanable for ParquetSourceNode {
 
     const DOES_PRED_PD: bool = true;
     const DOES_SLICE_PD: bool = true;
-    const DOES_ROW_INDEX: bool = true;
 
     async fn new(
         source: ScanSource,
         options: &Self::ReadOptions,
         cloud_options: Option<&CloudOptions>,
+        row_index: Option<PlSmallStr>,
     ) -> PolarsResult<Self> {
         let source = source.into_sources();
         let memslice = source.at(0).to_memslice()?;
         let file_metadata = read_metadata(&mut std::io::Cursor::new(memslice.as_ref()))?;
         let arrow_schema = infer_schema_with_options(&file_metadata, &None)?;
-        let schema = Arc::new(Schema::from_arrow_schema(&arrow_schema));
         let arrow_schema = Arc::new(arrow_schema);
 
+        let schema = Schema::from_arrow_schema(&arrow_schema);
+        let schema = Arc::new(schema);
+
         let mut options = options.clone();
         options.schema = Some(schema.clone());
-        let file_options = FileScanOptions::default();
+
+        let file_options = FileScanOptions {
+            row_index: row_index.map(|name| RowIndex { name, offset: 0 }),
+            ..Default::default()
+        };
+
         let file_info = FileInfo::new(
             schema.clone(),
             Some(rayon::iter::Either::Left(arrow_schema.clone())),
@@ -298,11 +313,22 @@ impl MultiScanable for ParquetSourceNode {
     }
 
     fn with_projection(&mut self, projection: Option<&Bitmap>) {
-        self.file_options.with_columns = projection.map(|p| {
-            p.true_idx_iter()
-                .map(|idx| self.file_info.schema.get_at_index(idx).unwrap().0.clone())
-                .collect()
-        });
+        if let Some(projection) = projection {
+            let mut with_columns = Vec::with_capacity(
+                usize::from(self.file_options.row_index.is_some()) + projection.set_bits(),
+            );
+
+            if let Some(ri) = self.file_options.row_index.as_ref() {
+                with_columns.push(ri.name.clone());
+            }
+            with_columns.extend(
+                projection
+                    .true_idx_iter()
+                    .map(|idx| self.file_info.schema.get_at_index(idx).unwrap().0.clone())
+                    .collect::<Vec<_>>(),
+            );
+            self.file_options.with_columns = Some(with_columns.into());
+        }
     }
 
     fn with_row_restriction(&mut self, row_restriction: Option<RowRestriction>) {
         self.predicate = None;
@@ -320,16 +346,14 @@ impl MultiScanable for ParquetSourceNode {
             }
         }
     }
-    fn with_row_index(&mut self, row_index: Option<PlSmallStr>) {
-        self.row_index = row_index.map(|name| Arc::new((name, AtomicIdxSize::new(0))));
-    }
 
-    async fn row_count(&mut self) -> PolarsResult<IdxSize> {
-        // @TODO: Overflow
-        Ok(self.first_metadata.as_ref().unwrap().num_rows as IdxSize)
+    async fn unrestricted_row_count(&mut self) -> PolarsResult<IdxSize> {
+        let num_rows = self.first_metadata.as_ref().unwrap().num_rows;
+        IdxSize::try_from(num_rows)
+            .map_err(|_| polars_err!(bigidx, ctx = "parquet file", size = num_rows))
    }
-    async fn schema(&mut self) -> PolarsResult<SchemaRef> {
+
+    async fn physical_schema(&mut self) -> PolarsResult<SchemaRef> {
         Ok(self.file_info.schema.clone())
     }
 }
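Because the Parquet reader is what materializes the row-index column, `with_projection` above must keep that column in `with_columns` alongside the projected physical columns. A simplified sketch of that rebuild, with `&str` and `bool` slices standing in for `PlSmallStr` and `Bitmap`:

```rust
// Sketch: prepend the row-index column (if any) to the projected column list.
// The reader would otherwise drop the index it just created.
fn projected_columns(
    schema: &[&str],
    projection: &[bool],
    row_index: Option<&str>,
) -> Vec<String> {
    let mut with_columns = Vec::with_capacity(usize::from(row_index.is_some()) + schema.len());
    if let Some(name) = row_index {
        with_columns.push(name.to_string());
    }
    with_columns.extend(
        projection
            .iter()
            .zip(schema)
            .filter_map(|(&keep, &name)| keep.then(|| name.to_string())),
    );
    with_columns
}

fn main() {
    let cols = projected_columns(&["a", "b", "c"], &[true, false, true], Some("ri"));
    assert_eq!(cols, ["ri", "a", "c"]);
}
```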
diff --git a/crates/polars-stream/src/physical_plan/lower_ir.rs b/crates/polars-stream/src/physical_plan/lower_ir.rs
index 8d693cfef258..37f38883f3f9 100644
--- a/crates/polars-stream/src/physical_plan/lower_ir.rs
+++ b/crates/polars-stream/src/physical_plan/lower_ir.rs
@@ -2,7 +2,7 @@ use std::sync::Arc;
 
 use parking_lot::Mutex;
 use polars_core::frame::{DataFrame, UniqueKeepStrategy};
-use polars_core::prelude::{DataType, InitHashMaps, PlHashMap, PlHashSet, PlIndexMap, IDX_DTYPE};
+use polars_core::prelude::{DataType, InitHashMaps, PlHashMap, PlHashSet, PlIndexMap};
 use polars_core::schema::{Schema, SchemaExt};
 use polars_core::utils::arrow::bitmap::MutableBitmap;
 use polars_error::PolarsResult;
@@ -396,11 +396,7 @@ pub fn lower_ir(
                 || file_options.allow_missing_columns
                 || std::env::var("POLARS_FORCE_MULTISCAN").as_deref() == Ok("1")
             {
-                let mut file_schema = file_info.schema.as_ref().clone();
-                if let Some(ri) = &file_options.row_index {
-                    // For now, we handle the row index separately.
-                    file_schema.shift_remove(ri.name.as_str());
-                }
+                let file_schema = file_info.schema.clone();
 
                 // Create a mask that indicates which columns are included in the projection.
                 let projection = file_options.with_columns.map(|with_columns| {
@@ -417,11 +413,15 @@ pub fn lower_ir(
                             .expect("we should have the column here");
                         projection.set(idx, true);
                     }
+                    if let Some(c) = file_options.row_index.as_ref() {
+                        let idx = file_schema
+                            .try_index_of(c.name.as_str())
+                            .expect("we should have the column here");
+                        projection.set(idx, true);
+                    }
                     projection.freeze()
                 });
 
-                let file_schema = Arc::new(file_schema);
-
                 // The schema afterwards only includes the projected columns.
                 let mut schema = if let Some(projection) = projection.as_ref() {
                     Arc::new(file_schema.as_ref().project_select(projection))
@@ -436,26 +436,9 @@ pub fn lower_ir(
                     allow_missing_columns: file_options.allow_missing_columns,
                     include_file_paths: file_options.include_file_paths,
                     projection,
+                    row_index: file_options.row_index,
                 };
 
-                if let Some(row_index) = file_options.row_index {
-                    let mut ri_schema = schema.as_ref().clone();
-                    ri_schema
-                        .insert_at_index(0, row_index.name.clone(), IDX_DTYPE)
-                        .unwrap();
-                    let source_node = phys_sm.insert(PhysNode {
-                        output_schema: schema,
-                        kind: node,
-                    });
-                    let stream = PhysStream::first(source_node);
-                    node = PhysNodeKind::WithRowIndex {
-                        input: stream,
-                        name: row_index.name,
-                        offset: Some(row_index.offset),
-                    };
-                    schema = Arc::new(ri_schema);
-                }
-
                 let proj_schema = Arc::new(schema.try_project(output_schema.iter_names_cloned())?);
                 let source_node = phys_sm.insert(PhysNode {
                     output_schema: schema,
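With the separate `WithRowIndex` node gone, `lower_ir` simply switches the row-index column on in the projection mask like any other column of `file_schema`. A sketch of the mask construction, with `Vec<bool>` standing in for `MutableBitmap`:

```rust
use std::collections::HashSet;

// Sketch: build a projection mask over the full file schema. The row-index
// name is chained into the wanted set exactly like an ordinary column.
fn projection_mask(
    file_schema: &[&str],
    with_columns: &[&str],
    row_index: Option<&str>,
) -> Vec<bool> {
    let wanted: HashSet<&str> = with_columns.iter().copied().chain(row_index).collect();
    file_schema.iter().map(|name| wanted.contains(name)).collect()
}

fn main() {
    let mask = projection_mask(&["ri", "a", "b"], &["b"], Some("ri"));
    assert_eq!(mask, vec![true, false, true]);
}
```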
diff --git a/crates/polars-stream/src/physical_plan/mod.rs b/crates/polars-stream/src/physical_plan/mod.rs
index 9073f830a71e..6b763fe25c46 100644
--- a/crates/polars-stream/src/physical_plan/mod.rs
+++ b/crates/polars-stream/src/physical_plan/mod.rs
@@ -6,6 +6,7 @@ use polars_core::prelude::{IdxSize, InitHashMaps, PlHashMap, SortMultipleOptions};
 use polars_core::schema::{Schema, SchemaRef};
 use polars_core::utils::arrow::bitmap::Bitmap;
 use polars_error::PolarsResult;
+use polars_io::RowIndex;
 use polars_ops::frame::JoinArgs;
 use polars_plan::dsl::JoinTypeOptionsIR;
 use polars_plan::plans::hive::HivePartitions;
@@ -177,7 +178,7 @@ pub enum PhysNodeKind {
 
         /// Schema that all files are coerced into.
         ///
-        /// - Does **not** include the `row_index`.
+        /// - Does include the `row_index`.
        /// - Does include `include_file_paths`.
        /// - Does include the hive columns.
        ///
@@ -189,6 +190,8 @@ pub enum PhysNodeKind {
 
         /// Selection of `file_schema` columns that should be included in the output morsels.
         projection: Option<Bitmap>,
+
+        row_index: Option<RowIndex>,
     },
     FileScan {
         scan_sources: ScanSources,
diff --git a/crates/polars-stream/src/physical_plan/to_graph.rs b/crates/polars-stream/src/physical_plan/to_graph.rs
index 51ef7f939370..ecb4176c3ea7 100644
--- a/crates/polars-stream/src/physical_plan/to_graph.rs
+++ b/crates/polars-stream/src/physical_plan/to_graph.rs
@@ -361,6 +361,7 @@ fn to_graph_rec<'a>(
             allow_missing_columns,
             include_file_paths,
             projection,
+            row_index,
         } => match scan_type {
             #[cfg(feature = "parquet")]
             polars_plan::plans::FileScan::Parquet {
@@ -376,6 +377,7 @@ fn to_graph_rec<'a>(
                         include_file_paths.clone(),
                         file_schema.clone(),
                         projection.clone(),
+                        row_index.clone(),
                         options.clone(),
                         cloud_options.clone(),
                     ),
@@ -398,6 +400,7 @@ fn to_graph_rec<'a>(
                         include_file_paths.clone(),
                         file_schema.clone(),
                         projection.clone(),
+                        row_index.clone(),
                         options.clone(),
                         cloud_options.clone(),
                     ),
@@ -419,6 +422,7 @@ fn to_graph_rec<'a>(
                         include_file_paths.clone(),
                         file_schema.clone(),
                         projection.clone(),
+                        row_index.clone(),
                         options.clone(),
                         cloud_options.clone(),
                     ),
diff --git a/py-polars/tests/unit/io/test_multiscan.py b/py-polars/tests/unit/io/test_multiscan.py
index f00d17fc5025..009c38b16c6b 100644
--- a/py-polars/tests/unit/io/test_multiscan.py
+++ b/py-polars/tests/unit/io/test_multiscan.py
@@ -4,6 +4,7 @@
 import pytest
 
 import polars as pl
+from polars.meta.index_type import get_index_type
 from polars.testing import assert_frame_equal
 
 
@@ -146,3 +147,74 @@ def test_multiscan_projection(
         .select(projection)
         .collect(new_streaming=True),  # type: ignore[call-overload]
     )
+
+
+@pytest.mark.parametrize(
+    ("scan", "write", "ext"),
+    [
+        (pl.scan_ipc, pl.DataFrame.write_ipc, "ipc"),
+        (pl.scan_parquet, pl.DataFrame.write_parquet, "parquet"),
+        (pl.scan_csv, pl.DataFrame.write_csv, "csv"),
+    ],
+)
+def test_multiscan_row_index(
+    tmp_path: Path,
+    scan: Callable[..., pl.LazyFrame],
+    write: Callable[[pl.DataFrame, Path], Any],
+    ext: str,
+) -> None:
+    a = pl.DataFrame({"col": [5, 10, 1996]})
+    b = pl.DataFrame({"col": [42]})
+    c = pl.DataFrame({"col": [13, 37]})
+
+    write(a, tmp_path / f"a.{ext}")
+    write(b, tmp_path / f"b.{ext}")
+    write(c, tmp_path / f"c.{ext}")
+
+    col = pl.concat([a, b, c]).to_series()
+    g = tmp_path / f"*.{ext}"
+
+    assert_frame_equal(
+        scan(g, row_index_name="ri").collect(),
+        pl.DataFrame(
+            [
+                pl.Series("ri", range(6), get_index_type()),
+                col,
+            ]
+        ),
+    )
+
+    start = 42
+    assert_frame_equal(
+        scan(g, row_index_name="ri", row_index_offset=start).collect(),
+        pl.DataFrame(
+            [
+                pl.Series("ri", range(start, start + 6), get_index_type()),
+                col,
+            ]
+        ),
+    )
+
+    start = 42
+    assert_frame_equal(
+        scan(g, row_index_name="ri", row_index_offset=start).slice(3, 3).collect(),
+        pl.DataFrame(
+            [
+                pl.Series("ri", range(start + 3, start + 6), get_index_type()),
+                col.slice(3, 3),
+            ]
+        ),
+    )
+
+    start = 42
+    assert_frame_equal(
+        scan(g, row_index_name="ri", row_index_offset=start)
+        .filter(pl.col("col") < 15)
+        .collect(),
+        pl.DataFrame(
+            [
+                pl.Series("ri", [start + 0, start + 1, start + 4], get_index_type()),
+                pl.Series("col", [5, 10, 13]),
+            ]
+        ),
+    )