Skip to content

Commit

Permalink
make memo table generic over expression traits
Browse files Browse the repository at this point in the history
This commit replaces the specific expression types with traits that
define the behavior that the in-memory representations of both logical and
physical expressions need to have. Right now, the `PhysicalExpression`
trait does not do that much, but the `LogicalExpression` trait is super
important to how the persistent memo table works.
  • Loading branch information
connortsui20 committed Dec 6, 2024
1 parent e9fba27 commit 2702eb4
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 110 deletions.
152 changes: 78 additions & 74 deletions optd-mvp/src/expression/logical_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,35 @@
//! TODO Figure out if each relation should be in a different submodule.
//! TODO This entire file is a WIP.
use crate::{entities::*, memo::GroupId};
use crate::{entities::logical_expression::Model, memo::GroupId};
use fxhash::hash;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;

/// An interface defining what an in-memory logical expression representation should be able to do.
///
/// Implementors must be convertible to and from the `logical_expression` entity [`Model`]
/// (via the `From<Model>` and `Into<Model>` supertraits), and must be `Clone` and `Debug`.
pub trait LogicalExpression: From<Model> + Into<Model> + Clone + Debug {
/// Returns the kind of relation / operator node encoded as an integer.
fn kind(&self) -> i16;

/// Retrieves the child group IDs of this logical expression.
fn children(&self) -> Vec<GroupId>;

/// Computes the fingerprint of this expression, which should generate an integer for equality
/// checks that has a low collision rate.
///
/// A fingerprint match alone does not imply equality; see [`Self::is_duplicate`] for the
/// actual duplicate check.
fn fingerprint(&self) -> i64;

/// Checks if the current expression is a duplicate of the other expression.
///
/// Note that this is similar to `Eq` and `PartialEq`, but the implementor should be aware that
/// different expressions can be duplicates of each other without having the exact same data.
fn is_duplicate(&self, other: &Self) -> bool;

/// Returns a copy of this expression rewritten to use new child group IDs, where `rewrites` is
/// a slice of tuples representing `(old_group_id, new_group_id)`.
///
/// The receiver is taken by shared reference and is not modified.
///
/// TODO: There's definitely a better way to represent this API
fn rewrite(&self, rewrites: &[(GroupId, GroupId)]) -> Self;
}

#[derive(Clone, Debug)]
pub enum DefaultLogicalExpression {
Expand All @@ -16,44 +42,32 @@ pub enum DefaultLogicalExpression {
Join(Join),
}

impl DefaultLogicalExpression {
pub fn kind(&self) -> i16 {
impl LogicalExpression for DefaultLogicalExpression {
fn kind(&self) -> i16 {
match self {
DefaultLogicalExpression::Scan(_) => 0,
DefaultLogicalExpression::Filter(_) => 1,
DefaultLogicalExpression::Join(_) => 2,
Self::Scan(_) => 0,
Self::Filter(_) => 1,
Self::Join(_) => 2,
}
}

/// Calculates the fingerprint of a given expression, but replaces all of the children group IDs
/// with a new group ID if it is listed in the input `rewrites` list.
///
/// TODO Allow each expression to implement a trait that does this.
pub fn fingerprint_with_rewrite(&self, rewrites: &[(GroupId, GroupId)]) -> i64 {
// Closure that rewrites a group ID if needed.
let rewrite = |x: GroupId| {
if rewrites.is_empty() {
return x;
}

if let Some(i) = rewrites.iter().position(|(curr, _new)| &x == curr) {
assert_eq!(rewrites[i].0, x);
rewrites[i].1
} else {
x
}
};
fn children(&self) -> Vec<GroupId> {
match self {
Self::Scan(_) => vec![],
Self::Filter(filter) => vec![filter.child],
Self::Join(join) => vec![join.left, join.right],
}
}

fn fingerprint(&self) -> i64 {
let kind = self.kind() as u16 as usize;
let hash = match self {
DefaultLogicalExpression::Scan(scan) => hash(scan.table.as_str()),
DefaultLogicalExpression::Filter(filter) => {
hash(&rewrite(filter.child).0) ^ hash(filter.expression.as_str())
}
DefaultLogicalExpression::Join(join) => {
Self::Scan(scan) => hash(scan.table.as_str()),
Self::Filter(filter) => hash(&filter.child.0) ^ hash(filter.expression.as_str()),
Self::Join(join) => {
// Make sure that there is a difference between `Join(A, B)` and `Join(B, A)`.
hash(&(rewrite(join.left).0 + 1))
^ hash(&(rewrite(join.right).0 + 2))
hash(&(join.left.0 + 1))
^ hash(&(join.right.0 + 2))
^ hash(join.expression.as_str())
}
};
Expand All @@ -62,10 +76,23 @@ impl DefaultLogicalExpression {
((hash & !0xFFFF) | kind) as i64
}

/// Checks equality between two expressions, with both expression rewriting their child group
/// IDs according to the input `rewrites` list.
pub fn eq_with_rewrite(&self, other: &Self, rewrites: &[(GroupId, GroupId)]) -> bool {
// Closure that rewrites a group ID if needed.
fn is_duplicate(&self, other: &Self) -> bool {
match (self, other) {
(Self::Scan(scan_left), Self::Scan(scan_right)) => scan_left.table == scan_right.table,
(Self::Filter(filter_left), Self::Filter(filter_right)) => {
filter_left.child == filter_right.child
&& filter_left.expression == filter_right.expression
}
(Self::Join(join_left), Self::Join(join_right)) => {
join_left.left == join_right.left
&& join_left.right == join_right.right
&& join_left.expression == join_right.expression
}
_ => false,
}
}

fn rewrite(&self, rewrites: &[(GroupId, GroupId)]) -> Self {
let rewrite = |x: GroupId| {
if rewrites.is_empty() {
return x;
Expand All @@ -79,35 +106,17 @@ impl DefaultLogicalExpression {
}
};

match (self, other) {
(
DefaultLogicalExpression::Scan(scan_left),
DefaultLogicalExpression::Scan(scan_right),
) => scan_left.table == scan_right.table,
(
DefaultLogicalExpression::Filter(filter_left),
DefaultLogicalExpression::Filter(filter_right),
) => {
rewrite(filter_left.child) == rewrite(filter_right.child)
&& filter_left.expression == filter_right.expression
}
(
DefaultLogicalExpression::Join(join_left),
DefaultLogicalExpression::Join(join_right),
) => {
rewrite(join_left.left) == rewrite(join_right.left)
&& rewrite(join_left.right) == rewrite(join_right.right)
&& join_left.expression == join_right.expression
}
_ => false,
}
}

pub fn children(&self) -> Vec<GroupId> {
match self {
DefaultLogicalExpression::Scan(_) => vec![],
DefaultLogicalExpression::Filter(filter) => vec![filter.child],
DefaultLogicalExpression::Join(join) => vec![join.left, join.right],
Self::Scan(_) => self.clone(),
Self::Filter(filter) => Self::Filter(Filter {
child: rewrite(filter.child),
expression: filter.expression.clone(),
}),
Self::Join(join) => Self::Join(Join {
left: rewrite(join.left),
right: rewrite(join.right),
expression: join.expression.clone(),
}),
}
}
}
Expand All @@ -130,9 +139,8 @@ pub struct Join {
expression: String,
}

/// TODO Use a macro.
impl From<logical_expression::Model> for DefaultLogicalExpression {
fn from(value: logical_expression::Model) -> Self {
impl From<Model> for DefaultLogicalExpression {
fn from(value: Model) -> Self {
match value.kind {
0 => Self::Scan(
serde_json::from_value(value.data)
Expand All @@ -151,14 +159,10 @@ impl From<logical_expression::Model> for DefaultLogicalExpression {
}
}

/// TODO Use a macro.
impl From<DefaultLogicalExpression> for logical_expression::Model {
fn from(value: DefaultLogicalExpression) -> logical_expression::Model {
fn create_logical_expression(
kind: i16,
data: serde_json::Value,
) -> logical_expression::Model {
logical_expression::Model {
impl From<DefaultLogicalExpression> for Model {
fn from(value: DefaultLogicalExpression) -> Model {
fn create_logical_expression(kind: i16, data: serde_json::Value) -> Model {
Model {
id: -1,
group_id: -1,
kind,
Expand Down
50 changes: 38 additions & 12 deletions optd-mvp/src/expression/physical_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,42 @@
//!
//! FIXME: All fields are placeholders.
//!
//! TODO Remove dead code.
//! TODO Figure out if each operator should be in a different submodule.
//! TODO This entire file is a WIP.
use crate::{entities::*, memo::GroupId};
#![allow(dead_code)]

use crate::{entities::physical_expression::Model, memo::GroupId};
use serde::{Deserialize, Serialize};
use std::fmt::Debug;

/// An interface defining what an in-memory physical expression representation should be able to do.
///
/// Implementors must be convertible to and from the `physical_expression` entity [`Model`]
/// (via the `From<Model>` and `Into<Model>` supertraits), and must be `Clone` and `Debug`.
pub trait PhysicalExpression: From<Model> + Into<Model> + Clone + Debug {
/// Returns the kind of relation / operator node encoded as an integer.
fn kind(&self) -> i16;

/// Retrieves the child group IDs of this physical expression.
fn children(&self) -> Vec<GroupId>;
}

impl PhysicalExpression for DefaultPhysicalExpression {
fn kind(&self) -> i16 {
match self {
Self::TableScan(_) => 0,
Self::Filter(_) => 1,
Self::HashJoin(_) => 2,
}
}

fn children(&self) -> Vec<GroupId> {
match self {
Self::TableScan(_) => vec![],
Self::Filter(filter) => vec![filter.child],
Self::HashJoin(hash_join) => vec![hash_join.left, hash_join.right],
}
}
}

#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DefaultPhysicalExpression {
Expand All @@ -33,9 +64,8 @@ pub struct HashJoin {
expression: String,
}

/// TODO Use a macro.
impl From<physical_expression::Model> for DefaultPhysicalExpression {
fn from(value: physical_expression::Model) -> Self {
impl From<Model> for DefaultPhysicalExpression {
fn from(value: Model) -> Self {
match value.kind {
0 => Self::TableScan(
serde_json::from_value(value.data)
Expand All @@ -54,14 +84,10 @@ impl From<physical_expression::Model> for DefaultPhysicalExpression {
}
}

/// TODO Use a macro.
impl From<DefaultPhysicalExpression> for physical_expression::Model {
fn from(value: DefaultPhysicalExpression) -> physical_expression::Model {
fn create_physical_expression(
kind: i16,
data: serde_json::Value,
) -> physical_expression::Model {
physical_expression::Model {
impl From<DefaultPhysicalExpression> for Model {
fn from(value: DefaultPhysicalExpression) -> Model {
fn create_physical_expression(kind: i16, data: serde_json::Value) -> Model {
Model {
id: -1,
group_id: -1,
kind,
Expand Down
Loading

0 comments on commit 2702eb4

Please sign in to comment.