From 918bb1ab6f1bd596b784ceec0d35b8752e3d0dd4 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Thu, 13 Feb 2025 15:38:37 +0100 Subject: [PATCH] refactor(rust): Improve hash join build sample implementation (#21236) --- Cargo.lock | 1 + crates/polars-stream/Cargo.toml | 1 + .../src/nodes/joins/equi_join.rs | 491 ++++++++++-------- 3 files changed, 266 insertions(+), 227 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ab4a943d7ebd..5461cbb67d74 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3471,6 +3471,7 @@ version = "0.46.0" dependencies = [ "atomic-waker", "crossbeam-deque", + "crossbeam-queue", "crossbeam-utils", "futures", "memmap2", diff --git a/crates/polars-stream/Cargo.toml b/crates/polars-stream/Cargo.toml index af3664be0e22..cbf2f9fbe528 100644 --- a/crates/polars-stream/Cargo.toml +++ b/crates/polars-stream/Cargo.toml @@ -11,6 +11,7 @@ description = "Private crate for the streaming execution engine for the Polars D [dependencies] atomic-waker = { workspace = true } crossbeam-deque = { workspace = true } +crossbeam-queue = { workspace = true } crossbeam-utils = { workspace = true } futures = { workspace = true } memmap = { workspace = true } diff --git a/crates/polars-stream/src/nodes/joins/equi_join.rs b/crates/polars-stream/src/nodes/joins/equi_join.rs index 76c7d37e102f..6e2a699b6878 100644 --- a/crates/polars-stream/src/nodes/joins/equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/equi_join.rs @@ -1,6 +1,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Arc, LazyLock}; +use crossbeam_queue::ArrayQueue; use polars_core::prelude::*; use polars_core::schema::{Schema, SchemaExt}; use polars_core::series::IsSorted; @@ -77,6 +78,7 @@ fn compute_payload_selector( .collect() } +/// Fixes names and does coalescing of columns post-join. 
fn postprocess_join(df: DataFrame, params: &EquiJoinParams) -> DataFrame { if params.args.how == JoinType::Full && params.args.should_coalesce() { // TODO: don't do string-based column lookups for each dataframe, pre-compute coalesce indices. @@ -150,6 +152,77 @@ fn select_payload(df: DataFrame, selector: &[Option]) -> DataFrame { .collect() } +fn estimate_cardinality( + morsels: &[Morsel], + key_selectors: &[StreamExpr], + params: &EquiJoinParams, + state: &ExecutionState, +) -> PolarsResult { + // TODO: parallelize. + let mut sketch = CardinalitySketch::new(); + for morsel in morsels { + let hash_keys = + get_runtime().block_on(select_keys(morsel.df(), key_selectors, params, state))?; + hash_keys.sketch_cardinality(&mut sketch); + } + Ok(sketch.estimate()) +} + +#[expect(clippy::needless_lifetimes)] +fn insert_cached_into_parallel_stream<'s, 'env>( + cached: &'s Option>, + num_pipelines: usize, + recv_port: Option>, + scope: &'s TaskScope<'s, 'env>, + join_handles: &mut Vec>>, +) -> Option>> { + let Some(cached) = cached.as_ref().filter(|c| !c.is_empty()) else { + return recv_port.map(|p| p.parallel()); + }; + + let receivers = if let Some(p) = recv_port { + p.parallel().into_iter().map(Some).collect_vec() + } else { + (0..num_pipelines).map(|_| None).collect_vec() + }; + + let source_token = SourceToken::new(); + let mut out = Vec::new(); + for orig_recv in receivers { + let (mut new_send, new_recv) = connector(); + out.push(new_recv); + let source_token = source_token.clone(); + join_handles.push(scope.spawn_task(TaskPriority::High, async move { + // Act like an InMemorySource node until cached morsels are consumed. 
+ let wait_group = WaitGroup::default(); + loop { + let Some(mut morsel) = cached.pop() else { + break; + }; + morsel.replace_source_token(source_token.clone()); + morsel.set_consume_token(wait_group.token()); + if new_send.send(morsel).await.is_err() { + return Ok(()); + } + wait_group.wait().await; + if source_token.stop_requested() { + return Ok(()); + } + } + + if let Some(mut recv) = orig_recv { + while let Ok(morsel) = recv.recv().await { + if new_send.send(morsel).await.is_err() { + break; + } + } + } + Ok(()) + })); + } + Some(out) +} + #[derive(Default)] struct SampleState { left: Vec, @@ -183,22 +256,169 @@ impl SampleState { this_final_len.store(*len, Ordering::Relaxed); Ok(()) } -} -fn estimate_cardinality( - morsels: &[Morsel], - key_selectors: &[StreamExpr], - params: &EquiJoinParams, - state: &ExecutionState, -) -> PolarsResult { - // TODO: parallelize. - let mut sketch = CardinalitySketch::new(); - for morsel in morsels { - let hash_keys = - get_runtime().block_on(select_keys(morsel.df(), key_selectors, params, state))?; - hash_keys.sketch_cardinality(&mut sketch); + fn try_transition_to_build( + &mut self, + recv: &[PortState], + num_pipelines: usize, + params: &mut EquiJoinParams, + table: &mut Option>, + ) -> PolarsResult> { + let left_saturated = self.left_len >= *SAMPLE_LIMIT; + let right_saturated = self.right_len >= *SAMPLE_LIMIT; + let left_done = recv[0] == PortState::Done || left_saturated; + let right_done = recv[1] == PortState::Done || right_saturated; + #[expect(clippy::nonminimal_bool)] + let stop_sampling = (left_done && right_done) + || (left_done && self.right_len >= LOPSIDED_SAMPLE_FACTOR * self.left_len) + || (right_done && self.left_len >= LOPSIDED_SAMPLE_FACTOR * self.right_len); + if !stop_sampling { + return Ok(None); + } + + if config::verbose() { + eprintln!( + "choosing equi-join build side, sample lengths are: {} vs. 
{}", + self.left_len, self.right_len + ); + } + + let estimate_cardinalities = || { + let execution_state = ExecutionState::new(); + let left_cardinality = estimate_cardinality( + &self.left, + &params.left_key_selectors, + params, + &execution_state, + )?; + let right_cardinality = estimate_cardinality( + &self.right, + &params.right_key_selectors, + params, + &execution_state, + )?; + let norm_left_factor = self.left_len.min(*SAMPLE_LIMIT) as f64 / self.left_len as f64; + let norm_right_factor = + self.right_len.min(*SAMPLE_LIMIT) as f64 / self.right_len as f64; + let norm_left_cardinality = (left_cardinality as f64 * norm_left_factor) as usize; + let norm_right_cardinality = (right_cardinality as f64 * norm_right_factor) as usize; + if config::verbose() { + eprintln!("estimated cardinalities are: {norm_left_cardinality} vs. {norm_right_cardinality}"); + } + PolarsResult::Ok((norm_left_cardinality, norm_right_cardinality)) + }; + + let left_is_build = match (left_saturated, right_saturated) { + (false, false) => { + if self.left_len * LOPSIDED_SAMPLE_FACTOR < self.right_len + || self.left_len > self.right_len * LOPSIDED_SAMPLE_FACTOR + { + // Don't bother estimating cardinality, just choose smaller as it's highly + // imbalanced. + self.left_len < self.right_len + } else { + let (lc, rc) = estimate_cardinalities()?; + // Let's assume for now that per element building a + // table is 3x more expensive than a probe, with + // unique keys getting an additional 3x factor for + // having to update the hash table in addition to the probe. + let left_build_cost = self.left_len * 3 + 3 * lc; + let left_probe_cost = self.left_len; + let right_build_cost = self.right_len * 3 + 3 * rc; + let right_probe_cost = self.right_len; + left_build_cost + right_probe_cost < left_probe_cost + right_build_cost + } + }, + + // Choose the unsaturated side, the saturated side could be + // arbitrarily big. 
+ (false, true) => true, + (true, false) => false, + + // Estimate cardinality and choose smaller. + (true, true) => { + let (lc, rc) = estimate_cardinalities()?; + lc < rc + }, + }; + + if config::verbose() { + eprintln!( + "build side chosen: {}", + if left_is_build { "left" } else { "right" } + ); + } + + // Transition to building state. + params.left_is_build = Some(left_is_build); + *table = Some(if left_is_build { + new_chunked_idx_table(params.left_key_schema.clone()) + } else { + new_chunked_idx_table(params.right_key_schema.clone()) + }); + + fn make_queue(v: Vec) -> Option> { + if v.is_empty() { + return None; + } + let queue = ArrayQueue::new(v.len()); + for morsel in v { + queue.push(morsel).unwrap(); + } + Some(queue) + } + + let mut sampled_build_morsels = make_queue(core::mem::take(&mut self.left)); + let mut sampled_probe_morsels = make_queue(core::mem::take(&mut self.right)); + if !left_is_build { + core::mem::swap(&mut sampled_build_morsels, &mut sampled_probe_morsels); + } + + let partitioner = HashPartitioner::new(num_pipelines, 0); + let mut build_state = BuildState { + partitions_per_worker: (0..num_pipelines).map(|_| Vec::new()).collect(), + sampled_probe_morsels, + }; + + // Simulate the sample build morsels flowing into the build side. 
+ if sampled_build_morsels.is_some() { + let state = ExecutionState::new(); + crate::async_executor::task_scope(|scope| { + let mut join_handles = Vec::new(); + let receivers = insert_cached_into_parallel_stream( + &sampled_build_morsels, + num_pipelines, + None, + scope, + &mut join_handles, + ) + .unwrap(); + + for (worker_ps, recv) in build_state.partitions_per_worker.iter_mut().zip(receivers) + { + join_handles.push(scope.spawn_task( + TaskPriority::High, + BuildState::partition_and_sink( + recv, + worker_ps, + partitioner.clone(), + params, + &state, + ), + )); + } + + polars_io::pl_async::get_runtime().block_on(async move { + for handle in join_handles { + handle.await?; + } + PolarsResult::Ok(()) + }) + })?; + } + + Ok(Some(build_state)) } - Ok(sketch.estimate()) } #[derive(Default)] @@ -211,7 +431,7 @@ struct BuildPartition { #[derive(Default)] struct BuildState { partitions_per_worker: Vec>, - sampled_probe_morsels: Vec, + sampled_probe_morsels: Option>, } impl BuildState { @@ -358,7 +578,6 @@ impl BuildState { table_per_partition, max_seq_sent: MorselSeq::default(), sampled_probe_morsels: core::mem::take(&mut self.sampled_probe_morsels), - sampled_probe_morsel_idx: AtomicUsize::new(0), } }) } @@ -375,8 +594,7 @@ struct ProbeTable { struct ProbeState { table_per_partition: Vec, max_seq_sent: MorselSeq, - sampled_probe_morsels: Vec, - sampled_probe_morsel_idx: AtomicUsize, + sampled_probe_morsels: Option>, } impl ProbeState { @@ -640,7 +858,11 @@ impl Drop for ProbeState { POOL.install(|| { // Parallel drop as the state might be quite big. self.table_per_partition.par_drain(..).for_each(drop); - self.sampled_probe_morsels.par_drain(..).for_each(drop); + if let Some(morsels) = &self.sampled_probe_morsels { + (0..morsels.len()) + .into_par_iter() + .for_each(|_| drop(morsels.pop())); + } }) } } @@ -856,54 +1078,6 @@ impl EquiJoinNode { } } -// Not ideal - doesn't support stop requests before all cached items are flushed. 
-#[expect(clippy::needless_lifetimes)] -fn insert_cached_into_parallel_stream<'s, 'env>( - cached: &'s [Morsel], - cached_idx: &'s AtomicUsize, - num_pipelines: usize, - recv_port: Option>, - scope: &'s TaskScope<'s, 'env>, - join_handles: &mut Vec>>, -) -> Option>> { - if cached_idx.load(Ordering::Relaxed) >= cached.len() { - return recv_port.map(|p| p.parallel()); - } - - let receivers = if let Some(p) = recv_port { - p.parallel().into_iter().map(Some).collect_vec() - } else { - (0..num_pipelines).map(|_| None).collect_vec() - }; - - let mut out = Vec::new(); - for orig_recv in receivers { - let (mut new_send, new_recv) = connector(); - out.push(new_recv); - join_handles.push(scope.spawn_task(TaskPriority::High, async move { - loop { - let idx = cached_idx.fetch_add(1, Ordering::Relaxed); - if idx >= cached.len() { - break; - } - if new_send.send(cached[idx].clone()).await.is_err() { - break; - } - } - - if let Some(mut recv) = orig_recv { - while let Ok(morsel) = recv.recv().await { - if new_send.send(morsel).await.is_err() { - break; - } - } - } - Ok(()) - })); - } - Some(out) -} - impl ComputeNode for EquiJoinNode { fn name(&self) -> &str { "equi_join" @@ -923,157 +1097,12 @@ impl ComputeNode for EquiJoinNode { // If we are sampling and both sides are done/filled, transition to building. 
if let EquiJoinState::Sample(sample_state) = &mut self.state { - let left_saturated = sample_state.left_len >= *SAMPLE_LIMIT; - let right_saturated = sample_state.right_len >= *SAMPLE_LIMIT; - let left_done = recv[0] == PortState::Done || left_saturated; - let right_done = recv[1] == PortState::Done || right_saturated; - #[expect(clippy::nonminimal_bool)] - let stop_sampling = (left_done && right_done) - || (left_done - && sample_state.right_len >= LOPSIDED_SAMPLE_FACTOR * sample_state.left_len) - || (right_done - && sample_state.left_len >= LOPSIDED_SAMPLE_FACTOR * sample_state.right_len); - if stop_sampling { - if config::verbose() { - eprintln!( - "choosing equi-join build side, sample lengths are: {} vs. {}", - sample_state.left_len, sample_state.right_len - ); - } - - let estimate_cardinalities = || { - let execution_state = ExecutionState::new(); - let left_cardinality = estimate_cardinality( - &sample_state.left, - &self.params.left_key_selectors, - &self.params, - &execution_state, - )?; - let right_cardinality = estimate_cardinality( - &sample_state.right, - &self.params.right_key_selectors, - &self.params, - &execution_state, - )?; - let norm_left_factor = sample_state.left_len.min(*SAMPLE_LIMIT) as f64 - / sample_state.left_len as f64; - let norm_right_factor = sample_state.right_len.min(*SAMPLE_LIMIT) as f64 - / sample_state.right_len as f64; - let norm_left_cardinality = - (left_cardinality as f64 * norm_left_factor) as usize; - let norm_right_cardinality = - (right_cardinality as f64 * norm_right_factor) as usize; - if config::verbose() { - eprintln!("estimated cardinalities are: {norm_left_cardinality} vs. 
{norm_right_cardinality}"); - } - PolarsResult::Ok((norm_left_cardinality, norm_right_cardinality)) - }; - - let left_is_build = match (left_saturated, right_saturated) { - (false, false) => { - if sample_state.left_len * LOPSIDED_SAMPLE_FACTOR < sample_state.right_len - || sample_state.left_len - > sample_state.right_len * LOPSIDED_SAMPLE_FACTOR - { - // Don't bother estimating cardinality, just choose smaller as it's highly - // imbalanced. - sample_state.left_len < sample_state.right_len - } else { - let (lc, rc) = estimate_cardinalities()?; - // Let's assume for now that per element building a - // table is 3x more expensive than a probe, with - // unique keys getting an additional 3x factor for - // having to update the hash table in addition to the probe. - let left_build_cost = sample_state.left_len * 3 + 3 * lc; - let left_probe_cost = sample_state.left_len; - let right_build_cost = sample_state.right_len * 3 + 3 * rc; - let right_probe_cost = sample_state.right_len; - left_build_cost + right_probe_cost < left_probe_cost + right_build_cost - } - }, - - // Choose the unsaturated side, the saturated side could be - // arbitrarily big. - (false, true) => true, - (true, false) => false, - - // Estimate cardinality and choose smaller. - (true, true) => { - let (lc, rc) = estimate_cardinalities()?; - lc < rc - }, - }; - if config::verbose() { - eprintln!( - "build side chosen: {}", - if left_is_build { "left" } else { "right" } - ); - } - - // Transition to building state. 
- self.params.left_is_build = Some(left_is_build); - self.table = Some(if left_is_build { - new_chunked_idx_table(self.params.left_key_schema.clone()) - } else { - new_chunked_idx_table(self.params.right_key_schema.clone()) - }); - let mut sampled_build_morsels = core::mem::take(&mut sample_state.left); - let mut sampled_probe_morsels = core::mem::take(&mut sample_state.right); - if !left_is_build { - core::mem::swap(&mut sampled_build_morsels, &mut sampled_probe_morsels); - } - - let partitioner = HashPartitioner::new(self.num_pipelines, 0); - let mut build_state = BuildState { - partitions_per_worker: (0..self.num_pipelines).map(|_| Vec::new()).collect(), - sampled_probe_morsels, - }; - - // Simulate the sample build morsels flowing into the build side. - if !sampled_build_morsels.is_empty() { - let state = ExecutionState::new(); - let sampled_build_morsel_idx = AtomicUsize::new(0); - crate::async_executor::task_scope(|scope| { - let mut join_handles = Vec::new(); - let receivers = insert_cached_into_parallel_stream( - &sampled_build_morsels, - &sampled_build_morsel_idx, - self.num_pipelines, - None, - scope, - &mut join_handles, - ) - .unwrap(); - - for (worker_ps, recv) in - build_state.partitions_per_worker.iter_mut().zip(receivers) - { - join_handles.push(scope.spawn_task( - TaskPriority::High, - BuildState::partition_and_sink( - recv, - worker_ps, - partitioner.clone(), - &self.params, - &state, - ), - )); - } - - polars_io::pl_async::get_runtime().block_on(async move { - for handle in join_handles { - handle.await?; - } - PolarsResult::Ok(()) - }) - })?; - } - - POOL.install(|| { - // Parallel drop as the state might be quite big. - sampled_build_morsels.into_par_iter().for_each(drop); - }); - + if let Some(build_state) = sample_state.try_transition_to_build( + recv, + self.num_pipelines, + &mut self.params, + &mut self.table, + )? 
{ self.state = EquiJoinState::Build(build_state); } } @@ -1097,8 +1126,11 @@ impl ComputeNode for EquiJoinNode { // If we are probing and the probe input is done, emit unmatched if // necessary, otherwise we're done. if let EquiJoinState::Probe(probe_state) = &mut self.state { - let samples_consumed = probe_state.sampled_probe_morsel_idx.load(Ordering::Relaxed) - >= probe_state.sampled_probe_morsels.len(); + let samples_consumed = probe_state + .sampled_probe_morsels + .as_ref() + .map(|m| m.is_empty()) + .unwrap_or(true); if samples_consumed && recv[probe_idx] == PortState::Done { if self.params.emit_unmatched_build() { if self.params.preserve_order_build { @@ -1162,12 +1194,15 @@ impl ComputeNode for EquiJoinNode { if recv[probe_idx] != PortState::Done { core::mem::swap(&mut send[0], &mut recv[probe_idx]); } else { - send[0] = if probe_state.sampled_probe_morsel_idx.load(Ordering::Relaxed) - < probe_state.sampled_probe_morsels.len() - { - PortState::Ready - } else { + let samples_consumed = probe_state + .sampled_probe_morsels + .as_ref() + .map(|m| m.is_empty()) + .unwrap_or(true); + send[0] = if samples_consumed { PortState::Done + } else { + PortState::Ready }; } recv[build_idx] = PortState::Done; @@ -1195,7 +1230,10 @@ impl ComputeNode for EquiJoinNode { } fn is_memory_intensive_pipeline_blocker(&self) -> bool { - matches!(self.state, EquiJoinState::Build { .. }) + matches!( + self.state, + EquiJoinState::Sample { .. } | EquiJoinState::Build { .. } + ) } fn spawn<'env, 's>( @@ -1283,7 +1321,6 @@ impl ComputeNode for EquiJoinNode { let senders = send_ports[0].take().unwrap().parallel(); let receivers = insert_cached_into_parallel_stream( &probe_state.sampled_probe_morsels, - &probe_state.sampled_probe_morsel_idx, self.num_pipelines, recv_ports[probe_idx].take(), scope,