diff --git a/optd-mvp/src/expression/logical_expression.rs b/optd-mvp/src/expression/logical_expression.rs index f4797bc..486f1cb 100644 --- a/optd-mvp/src/expression/logical_expression.rs +++ b/optd-mvp/src/expression/logical_expression.rs @@ -5,9 +5,35 @@ //! TODO Figure out if each relation should be in a different submodule. //! TODO This entire file is a WIP. -use crate::{entities::*, memo::GroupId}; +use crate::{entities::logical_expression::Model, memo::GroupId}; use fxhash::hash; use serde::{Deserialize, Serialize}; +use std::fmt::Debug; + +/// An interface defining what an in-memory logical expression representation should be able to do. +pub trait LogicalExpression: From + Into + Clone + Debug { + /// Returns the kind of relation / operator node encoded as an integer. + fn kind(&self) -> i16; + + /// Retrieves the child groups IDs of this logical expression. + fn children(&self) -> Vec; + + /// Computes the fingerprint of this expression, which should generate an integer for equality + /// checks that has a low collision rate. + fn fingerprint(&self) -> i64; + + /// Checks if the current expression is a duplicate of the other expression. + /// + /// Note that this is similar to `Eq` and `PartialEq`, but the implementor should be aware that + /// different expressions can be duplicates of each other without having the exact same data. + fn is_duplicate(&self, other: &Self) -> bool; + + /// Rewrites the expression to use new child groups IDs, where `rewrites` is a slice of tuples + /// representing `(old_group_id, new_group_id)`. + /// + /// TODO: There's definitely a better way to represent this API + fn rewrite(&self, rewrites: &[(GroupId, GroupId)]) -> Self; +} #[derive(Clone, Debug)] pub enum DefaultLogicalExpression { @@ -16,44 +42,32 @@ pub enum DefaultLogicalExpression { Join(Join), } -impl DefaultLogicalExpression { - pub fn kind(&self) -> i16 { +impl LogicalExpression for DefaultLogicalExpression { + fn kind(&self) -> i16 { match self { - DefaultLogicalExpression::Scan(_) => 0, - DefaultLogicalExpression::Filter(_) => 1, - DefaultLogicalExpression::Join(_) => 2, + Self::Scan(_) => 0, + Self::Filter(_) => 1, + Self::Join(_) => 2, } } - /// Calculates the fingerprint of a given expression, but replaces all of the children group IDs - /// with a new group ID if it is listed in the input `rewrites` list. - /// - /// TODO Allow each expression to implement a trait that does this. - pub fn fingerprint_with_rewrite(&self, rewrites: &[(GroupId, GroupId)]) -> i64 { - // Closure that rewrites a group ID if needed. - let rewrite = |x: GroupId| { - if rewrites.is_empty() { - return x; - } - - if let Some(i) = rewrites.iter().position(|(curr, _new)| &x == curr) { - assert_eq!(rewrites[i].0, x); - rewrites[i].1 - } else { - x - } - }; + fn children(&self) -> Vec { + match self { + Self::Scan(_) => vec![], + Self::Filter(filter) => vec![filter.child], + Self::Join(join) => vec![join.left, join.right], + } + } + fn fingerprint(&self) -> i64 { let kind = self.kind() as u16 as usize; let hash = match self { - DefaultLogicalExpression::Scan(scan) => hash(scan.table.as_str()), - DefaultLogicalExpression::Filter(filter) => { - hash(&rewrite(filter.child).0) ^ hash(filter.expression.as_str()) - } - DefaultLogicalExpression::Join(join) => { + Self::Scan(scan) => hash(scan.table.as_str()), + Self::Filter(filter) => hash(&filter.child.0) ^ hash(filter.expression.as_str()), + Self::Join(join) => { // Make sure that there is a difference between `Join(A, B)` and `Join(B, A)`. - hash(&(rewrite(join.left).0 + 1)) - ^ hash(&(rewrite(join.right).0 + 2)) + hash(&(join.left.0 + 1)) + ^ hash(&(join.right.0 + 2)) ^ hash(join.expression.as_str()) } }; @@ -62,10 +76,23 @@ impl DefaultLogicalExpression { ((hash & !0xFFFF) | kind) as i64 } - /// Checks equality between two expressions, with both expression rewriting their child group - /// IDs according to the input `rewrites` list. - pub fn eq_with_rewrite(&self, other: &Self, rewrites: &[(GroupId, GroupId)]) -> bool { - // Closure that rewrites a group ID if needed. + fn is_duplicate(&self, other: &Self) -> bool { + match (self, other) { + (Self::Scan(scan_left), Self::Scan(scan_right)) => scan_left.table == scan_right.table, + (Self::Filter(filter_left), Self::Filter(filter_right)) => { + filter_left.child == filter_right.child + && filter_left.expression == filter_right.expression + } + (Self::Join(join_left), Self::Join(join_right)) => { + join_left.left == join_right.left + && join_left.right == join_right.right + && join_left.expression == join_right.expression + } + _ => false, + } + } + + fn rewrite(&self, rewrites: &[(GroupId, GroupId)]) -> Self { let rewrite = |x: GroupId| { if rewrites.is_empty() { return x; @@ -79,35 +106,17 @@ impl DefaultLogicalExpression { } }; - match (self, other) { - ( - DefaultLogicalExpression::Scan(scan_left), - DefaultLogicalExpression::Scan(scan_right), - ) => scan_left.table == scan_right.table, - ( - DefaultLogicalExpression::Filter(filter_left), - DefaultLogicalExpression::Filter(filter_right), - ) => { - rewrite(filter_left.child) == rewrite(filter_right.child) - && filter_left.expression == filter_right.expression - } - ( - DefaultLogicalExpression::Join(join_left), - DefaultLogicalExpression::Join(join_right), - ) => { - rewrite(join_left.left) == rewrite(join_right.left) - && rewrite(join_left.right) == rewrite(join_right.right) - && join_left.expression == join_right.expression - } - _ => false, - } - } - - pub fn children(&self) -> Vec { match self { - DefaultLogicalExpression::Scan(_) => vec![], - DefaultLogicalExpression::Filter(filter) => vec![filter.child], - DefaultLogicalExpression::Join(join) => vec![join.left, join.right], + Self::Scan(_) => self.clone(), + Self::Filter(filter) => Self::Filter(Filter { + child: rewrite(filter.child), + expression: filter.expression.clone(), + }), + Self::Join(join) => Self::Join(Join { + left: rewrite(join.left), + right: rewrite(join.right), + expression: join.expression.clone(), + }), } } } @@ -130,9 +139,8 @@ pub struct Join { expression: String, } -/// TODO Use a macro. -impl From for DefaultLogicalExpression { - fn from(value: logical_expression::Model) -> Self { +impl From for DefaultLogicalExpression { + fn from(value: Model) -> Self { match value.kind { 0 => Self::Scan( serde_json::from_value(value.data) @@ -151,14 +159,10 @@ impl From for DefaultLogicalExpression { } } -/// TODO Use a macro. -impl From for logical_expression::Model { - fn from(value: DefaultLogicalExpression) -> logical_expression::Model { - fn create_logical_expression( - kind: i16, - data: serde_json::Value, - ) -> logical_expression::Model { - logical_expression::Model { +impl From for Model { + fn from(value: DefaultLogicalExpression) -> Model { + fn create_logical_expression(kind: i16, data: serde_json::Value) -> Model { + Model { id: -1, group_id: -1, kind, diff --git a/optd-mvp/src/expression/physical_expression.rs b/optd-mvp/src/expression/physical_expression.rs index fb8692c..d7f71de 100644 --- a/optd-mvp/src/expression/physical_expression.rs +++ b/optd-mvp/src/expression/physical_expression.rs @@ -2,11 +2,42 @@ //! //! FIXME: All fields are placeholders. //! +//! TODO Remove dead code. //! TODO Figure out if each operator should be in a different submodule. //! TODO This entire file is a WIP. -use crate::{entities::*, memo::GroupId}; +#![allow(dead_code)] + +use crate::{entities::physical_expression::Model, memo::GroupId}; use serde::{Deserialize, Serialize}; +use std::fmt::Debug; + +/// An interface defining what an in-memory physical expression representation should be able to do. +pub trait PhysicalExpression: From + Into + Clone + Debug { + /// Returns the kind of relation / operator node encoded as an integer. + fn kind(&self) -> i16; + + /// Retrieves the child groups IDs of this logical expression. + fn children(&self) -> Vec; +} + +impl PhysicalExpression for DefaultPhysicalExpression { + fn kind(&self) -> i16 { + match self { + Self::TableScan(_) => 0, + Self::Filter(_) => 1, + Self::HashJoin(_) => 2, + } + } + + fn children(&self) -> Vec { + match self { + Self::TableScan(_) => vec![], + Self::Filter(filter) => vec![filter.child], + Self::HashJoin(hash_join) => vec![hash_join.left, hash_join.right], + } + } +} #[derive(Clone, Debug, PartialEq, Eq)] pub enum DefaultPhysicalExpression { @@ -33,9 +64,8 @@ pub struct HashJoin { expression: String, } -/// TODO Use a macro. -impl From for DefaultPhysicalExpression { - fn from(value: physical_expression::Model) -> Self { +impl From for DefaultPhysicalExpression { + fn from(value: Model) -> Self { match value.kind { 0 => Self::TableScan( serde_json::from_value(value.data) @@ -54,14 +84,10 @@ impl From for DefaultPhysicalExpression { } } -/// TODO Use a macro. -impl From for physical_expression::Model { - fn from(value: DefaultPhysicalExpression) -> physical_expression::Model { - fn create_physical_expression( - kind: i16, - data: serde_json::Value, - ) -> physical_expression::Model { - physical_expression::Model { +impl From for Model { + fn from(value: DefaultPhysicalExpression) -> Model { + fn create_physical_expression(kind: i16, data: serde_json::Value) -> Model { + Model { id: -1, group_id: -1, kind, diff --git a/optd-mvp/src/memo/persistent/implementation.rs b/optd-mvp/src/memo/persistent/implementation.rs index 8a60bc9..8d994b3 100644 --- a/optd-mvp/src/memo/persistent/implementation.rs +++ b/optd-mvp/src/memo/persistent/implementation.rs @@ -9,7 +9,7 @@ use super::PersistentMemo; use crate::{ entities::*, - expression::{DefaultLogicalExpression, DefaultPhysicalExpression}, + expression::{LogicalExpression, PhysicalExpression}, memo::{GroupId, GroupStatus, LogicalExpressionId, MemoError, PhysicalExpressionId}, OptimizerResult, DATABASE_URL, }; @@ -18,14 +18,20 @@ use sea_orm::{ entity::{IntoActiveModel, NotSet, Set}, Database, }; -use std::collections::HashSet; +use std::{collections::HashSet, marker::PhantomData}; -impl PersistentMemo { +impl PersistentMemo +where + L: LogicalExpression, + P: PhysicalExpression, +{ /// Creates a new `PersistentMemo` struct by connecting to a database defined at /// [`DATABASE_URL`]. pub async fn new() -> Self { Self { db: Database::connect(DATABASE_URL).await.unwrap(), + _phantom_logical: PhantomData, + _phantom_physical: PhantomData, } } @@ -147,7 +153,7 @@ impl PersistentMemo { pub async fn get_physical_expression( &self, physical_expression_id: PhysicalExpressionId, - ) -> OptimizerResult<(GroupId, DefaultPhysicalExpression)> { + ) -> OptimizerResult<(GroupId, P)> { // Lookup the entity in the database via the unique expression ID. let model = physical_expression::Entity::find_by_id(physical_expression_id.0) .one(&self.db) @@ -167,7 +173,7 @@ impl PersistentMemo { pub async fn get_logical_expression( &self, logical_expression_id: LogicalExpressionId, - ) -> OptimizerResult<(GroupId, DefaultLogicalExpression)> { + ) -> OptimizerResult<(GroupId, L)> { // Lookup the entity in the database via the unique expression ID. let model = logical_expression::Entity::find_by_id(logical_expression_id.0) .one(&self.db) @@ -288,7 +294,7 @@ impl PersistentMemo { pub async fn add_logical_expression_to_group( &self, group_id: GroupId, - logical_expression: DefaultLogicalExpression, + logical_expression: L, children: &[GroupId], ) -> OptimizerResult> { // Check if the expression already exists anywhere in the memo table. @@ -323,7 +329,7 @@ impl PersistentMemo { .await?; // Finally, insert the fingerprint of the logical expression as well. - let new_expr: DefaultLogicalExpression = new_model.into(); + let new_expr: L = new_model.into(); let kind = new_expr.kind(); // In order to calculate a correct fingerprint, we will want to use the IDs of the root @@ -333,7 +339,7 @@ impl PersistentMemo { let root_id = self.get_root_group(child_id).await?; rewrites.push((child_id, root_id)); } - let hash = new_expr.fingerprint_with_rewrite(&rewrites); + let hash = new_expr.rewrite(&rewrites).fingerprint(); let fingerprint = fingerprint::ActiveModel { id: NotSet, @@ -359,7 +365,7 @@ impl PersistentMemo { pub async fn add_physical_expression_to_group( &self, group_id: GroupId, - physical_expression: DefaultPhysicalExpression, + physical_expression: P, children: &[GroupId], ) -> OptimizerResult { // Check if the group actually exists. @@ -399,7 +405,7 @@ impl PersistentMemo { /// expression should _not_ have G2 as a child, and should be replaced with G1. pub async fn is_duplicate_logical_expression( &self, - logical_expression: &DefaultLogicalExpression, + logical_expression: &L, children: &[GroupId], ) -> OptimizerResult> { let model: logical_expression::Model = logical_expression.clone().into(); @@ -415,7 +421,7 @@ impl PersistentMemo { let root_id = self.get_root_group(child_id).await?; rewrites.push((child_id, root_id)); } - let fingerprint = logical_expression.fingerprint_with_rewrite(&rewrites); + let fingerprint = logical_expression.rewrite(&rewrites).fingerprint(); // Filter first by the fingerprint, and then the kind. // FIXME: The kind is already embedded into the fingerprint, so we may not actually need the @@ -447,7 +453,10 @@ impl PersistentMemo { } // Check for an exact match after rewrites. - if logical_expression.eq_with_rewrite(&expr, &rewrites) { + if logical_expression + .rewrite(&rewrites) + .is_duplicate(&expr.rewrite(&rewrites)) + { match_id = Some((group_id, expr_id)); // There should be at most one duplicate expression, so we can break here. @@ -473,7 +482,7 @@ impl PersistentMemo { /// expression, returning brand new IDs for both. pub async fn add_group( &self, - logical_expression: DefaultLogicalExpression, + logical_expression: L, children: &[GroupId], ) -> OptimizerResult> { @@ -517,7 +526,7 @@ impl PersistentMemo { .await?; // Finally, insert the fingerprint of the logical expression as well. - let new_logical_expression: DefaultLogicalExpression = new_expression.into(); + let new_logical_expression: L = new_expression.into(); let kind = new_logical_expression.kind(); // In order to calculate a correct fingerprint, we will want to use the IDs of the root @@ -527,7 +536,7 @@ impl PersistentMemo { let root_id = self.get_root_group(child_id).await?; rewrites.push((child_id, root_id)); } - let hash = new_logical_expression.fingerprint_with_rewrite(&rewrites); + let hash = new_logical_expression.rewrite(&rewrites).fingerprint(); let fingerprint = fingerprint::ActiveModel { id: NotSet, @@ -606,8 +615,8 @@ impl PersistentMemo { seen.insert(expr_id); } - let logical_expression: DefaultLogicalExpression = model.into(); - let hash = logical_expression.fingerprint_with_rewrite(&rewrites); + let logical_expression: L = model.into(); + let hash = logical_expression.rewrite(&rewrites).fingerprint(); let fingerprint = fingerprint::ActiveModel { id: NotSet, diff --git a/optd-mvp/src/memo/persistent/mod.rs b/optd-mvp/src/memo/persistent/mod.rs index ed64fc5..1f5466c 100644 --- a/optd-mvp/src/memo/persistent/mod.rs +++ b/optd-mvp/src/memo/persistent/mod.rs @@ -2,6 +2,7 @@ //! implements the `Memo` trait and supports memo table operations necessary for query optimization. use sea_orm::DatabaseConnection; +use std::marker::PhantomData; #[cfg(test)] mod tests; @@ -9,10 +10,16 @@ mod tests; /// A persistent memo table, backed by a database on disk. /// /// TODO more docs. -pub struct PersistentMemo { +pub struct PersistentMemo { /// This `PersistentMemo` is reliant on the SeaORM [`DatabaseConnection`] that stores all of the /// objects needed for query optimization. db: DatabaseConnection, + + /// Generic marker for a generic logical expression. + _phantom_logical: PhantomData, + + /// Generic marker for a generic physical expression. + _phantom_physical: PhantomData

, } mod implementation; diff --git a/optd-mvp/src/memo/persistent/tests.rs b/optd-mvp/src/memo/persistent/tests.rs index 7493363..12838f6 100644 --- a/optd-mvp/src/memo/persistent/tests.rs +++ b/optd-mvp/src/memo/persistent/tests.rs @@ -4,7 +4,7 @@ use crate::{expression::*, memo::persistent::PersistentMemo}; #[ignore] #[tokio::test] async fn test_simple_logical_duplicates() { - let memo = PersistentMemo::new().await; + let memo = PersistentMemo::::new().await; memo.cleanup().await; let scan = scan("t1".to_string()); @@ -95,7 +95,7 @@ async fn test_simple_add_physical_expression() { #[ignore] #[tokio::test] async fn test_simple_tree() { - let memo = PersistentMemo::new().await; + let memo = PersistentMemo::::new().await; memo.cleanup().await; // Create two scan groups. @@ -145,7 +145,7 @@ async fn test_simple_tree() { #[ignore] #[tokio::test] async fn test_simple_group_link() { - let memo = PersistentMemo::new().await; + let memo = PersistentMemo::::new().await; memo.cleanup().await; // Create two scan groups. @@ -198,10 +198,11 @@ async fn test_simple_group_link() { memo.cleanup().await; } +/// Tests merging groups up a chain. #[ignore] #[tokio::test] async fn test_group_merge_ladder() { - let memo = PersistentMemo::new().await; + let memo = PersistentMemo::::new().await; memo.cleanup().await; // Build up a tree of true filters that should be collapsed into a single table scan. @@ -259,7 +260,7 @@ async fn test_group_merge_ladder() { #[ignore] #[tokio::test] async fn test_group_merge() { - let memo = PersistentMemo::new().await; + let memo = PersistentMemo::::new().await; memo.cleanup().await; // Create a base group. @@ -386,7 +387,7 @@ async fn test_group_merge() { #[ignore] #[tokio::test] async fn test_cascading_merge() { - let memo = PersistentMemo::new().await; + let memo = PersistentMemo::::new().await; memo.cleanup().await; // Create the base groups.