diff --git a/Cargo.toml b/Cargo.toml index 6d153aa..90e7d6f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ permutation = "0.4.1" ordered-float = "3.1.0" tabled = { version = "0.10.0", optional = true } fmt-derive = "0.0.5" +phf = { version = "0.11.1", features = ["macros"] } [features] default = ["terminal-output"] diff --git a/src/codegen.rs b/src/codegen.rs index 5864ba3..a5e94e1 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -1,18 +1,21 @@ //! Intermediate code generation from the AST. use sqlparser::{ - ast::{self, SelectItem, SetExpr, Statement, TableFactor, TableWithJoins}, + ast::{self, FunctionArg, SelectItem, SetExpr, Statement, TableFactor, TableWithJoins}, parser::ParserError, }; +use phf::{phf_set, Set}; + use std::{error::Error, fmt::Display}; use crate::{ - expr::{Expr, ExprError}, - ic::{Instruction, IntermediateCode}, - identifier::IdentifierError, + expr::{agg::AggregateFunction, BinOp, Expr, ExprError, UnOp}, + identifier::{ColumnRef, IdentifierError}, + ir::{Instruction, IntermediateCode}, parser::parse, value::{Value, ValueError}, vm::RegisterIndex, + BoundedString, }; /// Represents either a parser error or a codegen error. @@ -54,11 +57,68 @@ pub fn codegen_str(code: &str) -> Result, ParserOrCodegenE .collect::, CodegenError>>()?) } +const TEMP_COL_NAME_PREFIX: &'static str = "__otter_temp_col"; + +/// Context passed around to any func that needs codegen. +struct CodegenContext { + pub instrs: Vec, + current_reg: RegisterIndex, + last_temp_col_num: usize, +} + +impl CodegenContext { + pub fn new() -> Self { + Self { + instrs: Vec::new(), + current_reg: RegisterIndex::default(), + last_temp_col_num: 0, + } + } + + pub fn get_and_increment_reg(&mut self) -> RegisterIndex { + let reg = self.current_reg; + self.current_reg = self.current_reg.next_index(); + reg + } + + pub fn get_new_temp_col(&mut self) -> BoundedString { + self.last_temp_col_num += 1; + format!("{TEMP_COL_NAME_PREFIX}_{}", self.last_temp_col_num) + .as_str() + .into() + } +} + +impl Default for CodegenContext { + fn default() -> Self { + Self::new() + } +} + +static AGGREGATE_FUNCTIONS: Set<&'static str> = phf_set! { + "count", + "max", + "min", + "sum", +}; + +fn extract_alias_from_project( + projection: &SelectItem, +) -> Result, CodegenError> { + match projection { + SelectItem::UnnamedExpr(_) => Ok(None), + SelectItem::ExprWithAlias { alias, .. } => Ok(Some(alias.value.as_str().into())), + SelectItem::QualifiedWildcard(name) => Err(CodegenError::UnsupportedStatementForm( + "Qualified wildcards are not supported yet", + name.to_string(), + )), + SelectItem::Wildcard => Ok(None), + } +} + /// Generates intermediate code from the AST. pub fn codegen_ast(ast: &Statement) -> Result { - let mut instrs = Vec::::new(); - - let mut current_reg = RegisterIndex::default(); + let mut ctx = CodegenContext::default(); match ast { Statement::CreateTable { @@ -87,35 +147,33 @@ pub fn codegen_ast(ast: &Statement) -> Result { collation: _, on_commit: _, } => { - let table_reg_index = current_reg; - instrs.push(Instruction::Empty { + let table_reg_index = ctx.get_and_increment_reg(); + ctx.instrs.push(Instruction::Empty { index: table_reg_index, }); - current_reg = current_reg.next_index(); - let col_reg_index = current_reg; - current_reg = current_reg.next_index(); + let col_reg_index = ctx.get_and_increment_reg(); for col in columns { - instrs.push(Instruction::ColumnDef { + ctx.instrs.push(Instruction::ColumnDef { index: col_reg_index, name: col.name.value.as_str().into(), data_type: col.data_type.clone(), }); for option in col.options.iter() { - instrs.push(Instruction::AddColumnOption { + ctx.instrs.push(Instruction::AddColumnOption { index: col_reg_index, option: option.clone(), }); } - instrs.push(Instruction::AddColumn { + ctx.instrs.push(Instruction::AddColumn { table_reg_index, col_index: col_reg_index, }); } - instrs.push(Instruction::NewTable { + ctx.instrs.push(Instruction::NewTable { index: table_reg_index, name: name.0.clone().try_into()?, exists_ok: *if_not_exists, @@ -134,22 +192,20 @@ pub fn codegen_ast(ast: &Statement) -> Result { table: _, on: _, } => { - let table_reg_index = current_reg; - instrs.push(Instruction::Source { + let table_reg_index = ctx.get_and_increment_reg(); + ctx.instrs.push(Instruction::Source { index: table_reg_index, name: table_name.0.clone().try_into()?, }); - current_reg = current_reg.next_index(); - let insert_reg_index = current_reg; - instrs.push(Instruction::InsertDef { + let insert_reg_index = ctx.get_and_increment_reg(); + ctx.instrs.push(Instruction::InsertDef { table_reg_index, index: insert_reg_index, }); - current_reg = current_reg.next_index(); for col in columns { - instrs.push(Instruction::ColumnInsertDef { + ctx.instrs.push(Instruction::ColumnInsertDef { insert_index: insert_reg_index, col_name: col.value.as_str().into(), }) @@ -162,19 +218,22 @@ pub fn codegen_ast(ast: &Statement) -> Result { .. } => { for row in values.0.clone() { - let row_reg = current_reg; - current_reg = current_reg.next_index(); + let row_reg = ctx.get_and_increment_reg(); - instrs.push(Instruction::RowDef { + ctx.instrs.push(Instruction::RowDef { insert_index: insert_reg_index, row_index: row_reg, }); - for value in row { - instrs.push(Instruction::AddValue { + for value_ast in row { + let value = codegen_expr(value_ast.clone(), &mut ctx)?.get_non_agg( + "Aggregate expressions are not supported in values", + value_ast, + )?; + ctx.instrs.push(Instruction::AddValue { row_index: row_reg, - expr: value.try_into()?, - }); + expr: value, + }) } } Ok(()) @@ -185,7 +244,7 @@ pub fn codegen_ast(ast: &Statement) -> Result { )), }?; - instrs.push(Instruction::Insert { + ctx.instrs.push(Instruction::Insert { index: insert_reg_index, }); @@ -193,8 +252,7 @@ pub fn codegen_ast(ast: &Statement) -> Result { } Statement::Query(query) => { // TODO: support CTEs - let mut table_reg_index = current_reg; - current_reg = current_reg.next_index(); + let mut table_reg_index = ctx.get_and_increment_reg(); match &query.body { SetExpr::Select(select) => { @@ -215,7 +273,7 @@ pub fn codegen_ast(ast: &Statement) -> Result { alias: _, args: _, with_hints: _, - } => instrs.push(Instruction::Source { + } => ctx.instrs.push(Instruction::Source { index: table_reg_index, name: name.0.clone().try_into()?, }), @@ -249,7 +307,7 @@ pub fn codegen_ast(ast: &Statement) -> Result { } } } - &[] => instrs.push(Instruction::NonExistent { + &[] => ctx.instrs.push(Instruction::NonExistent { index: table_reg_index, }), _ => { @@ -261,62 +319,162 @@ pub fn codegen_ast(ast: &Statement) -> Result { } } - if let Some(expr) = select.selection.clone() { - instrs.push(Instruction::Filter { + if let Some(expr_ast) = select.selection.clone() { + let expr = codegen_expr(expr_ast.clone(), &mut ctx)?.get_non_agg( + "Aggregate expressions are not supported in WHERE clause. Use HAVING clause instead", + expr_ast, + )?; + ctx.instrs.push(Instruction::Filter { index: table_reg_index, - expr: expr.try_into()?, + expr, }) } - for group_by in select.group_by.clone() { - instrs.push(Instruction::GroupBy { - index: table_reg_index, - expr: group_by.try_into()?, + let inter_exprs = select + .projection + .iter() + .cloned() + .map(|projection| codegen_selectitem(&projection, &mut ctx)) + .collect::, _>>()?; + + // if there are groupby + aggregations, we project all operations within an + // aggregation to another table first. for example, `SUM(col * col)` would be + // evaluated as `Project (col * col)` into `%2` and then apply the group by on + // `%2`. + let pre_grouped_reg_index = table_reg_index; + // let mut agg_intermediate_cols = Vec::new(); + // if !select.projection.is_empty() { + if !inter_exprs.is_empty() { + let pre_grouped_inter_reg_index = ctx.get_and_increment_reg(); + ctx.instrs.push(Instruction::Empty { + index: pre_grouped_inter_reg_index, }); + + table_reg_index = pre_grouped_inter_reg_index; + + for inter_expr in inter_exprs.iter() { + match inter_expr { + IntermediateExpr::Agg(agg) => { + for (expr, projected_col_name) in agg.pre_agg.clone() { + ctx.instrs.push(Instruction::Project { + input: pre_grouped_reg_index, + output: pre_grouped_inter_reg_index, + expr, + alias: Some(projected_col_name), + }); + } + } + IntermediateExpr::NonAgg(_) => {} + } + } } - if let Some(expr) = select.having.clone() { - instrs.push(Instruction::Filter { - index: table_reg_index, - expr: expr.try_into()?, - }) + for group_by_ast in select.group_by.clone() { + let group_by = codegen_expr(group_by_ast.clone(), &mut ctx)?.get_non_agg( + "Aggregate expressions are not supported in the GROUP BY clause", + group_by_ast, + )?; + let grouped_reg_index = ctx.get_and_increment_reg(); + ctx.instrs.push(Instruction::Empty { + index: grouped_reg_index, + }); + ctx.instrs.push(Instruction::GroupBy { + input: table_reg_index, + output: grouped_reg_index, + expr: group_by, + }); + table_reg_index = grouped_reg_index } - if !select.projection.is_empty() { - let original_table_reg_index = table_reg_index; - table_reg_index = current_reg; - current_reg = current_reg.next_index(); + // this is only for aggregations. + // aggs are applied on the grouped table created by the `GroupBy` instructions + // generated above. + if !inter_exprs.is_empty() { + let has_aggs = inter_exprs + .iter() + .any(|ie| matches!(ie, IntermediateExpr::Agg(_))); - instrs.push(Instruction::Empty { - index: table_reg_index, - }); + if has_aggs { + // codegen the aggregations themselves to an intermediate table + let original_table_reg_index = table_reg_index; + table_reg_index = ctx.get_and_increment_reg(); + + ctx.instrs.push(Instruction::Empty { + index: table_reg_index, + }); - for projection in select.projection.clone() { - instrs.push(Instruction::Project { - input: original_table_reg_index, - output: table_reg_index, - expr: match projection { - SelectItem::UnnamedExpr(ref expr) => expr.clone().try_into()?, - SelectItem::ExprWithAlias { ref expr, .. } => { - expr.clone().try_into()? + for inter_expr in &inter_exprs { + match inter_expr { + IntermediateExpr::Agg(agg) => { + for (agg_fn, col_name, alias) in &agg.agg { + ctx.instrs.push(Instruction::Aggregate { + input: original_table_reg_index, + output: table_reg_index, + func: agg_fn.clone(), + col_name: *col_name, + alias: Some(*alias), + }); + } } - SelectItem::QualifiedWildcard(_) => Expr::Wildcard, - SelectItem::Wildcard => Expr::Wildcard, - }, - alias: match projection { - SelectItem::UnnamedExpr(_) => None, - SelectItem::ExprWithAlias { alias, .. } => { - Some(alias.value.as_str().into()) + IntermediateExpr::NonAgg(_) => {} + } + } + + let last_grouped_reg_index = table_reg_index; + table_reg_index = ctx.get_and_increment_reg(); + + ctx.instrs.push(Instruction::Empty { + index: table_reg_index, + }); + + for (projection, inter_expr) in + select.projection.iter().zip(inter_exprs.iter()) + { + let alias = extract_alias_from_project(&projection)?; + + match inter_expr { + IntermediateExpr::Agg(agg) => { + for expr in agg.post_agg.clone() { + ctx.instrs.push(Instruction::Project { + input: last_grouped_reg_index, + output: table_reg_index, + expr, + alias, + }) + } } - SelectItem::QualifiedWildcard(name) => { - return Err(CodegenError::UnsupportedStatementForm( - "Qualified wildcards are not supported yet", - name.to_string(), - )) + IntermediateExpr::NonAgg(expr) => { + let alias = extract_alias_from_project(&projection)?; + let projection = Instruction::Project { + input: pre_grouped_reg_index, + output: table_reg_index, + expr: expr.clone(), + alias, + }; + ctx.instrs.push(projection) } - SelectItem::Wildcard => None, - }, - }) + } + } + } else { + for (projection, inter_expr) in + select.projection.iter().zip(inter_exprs.iter()) + { + match inter_expr { + IntermediateExpr::NonAgg(expr) => { + let alias = extract_alias_from_project(&projection)?; + let projection = Instruction::Project { + input: pre_grouped_reg_index, + output: table_reg_index, + expr: expr.clone(), + alias, + }; + ctx.instrs.push(projection) + } + IntermediateExpr::Agg(_) => { + unreachable!("already checked for aggregates") + } + } + } } if select.distinct { @@ -326,11 +484,30 @@ pub fn codegen_ast(ast: &Statement) -> Result { )); } } + + if let Some(expr_ast) = select.having.clone() { + let expr = codegen_expr(expr_ast.clone(), &mut ctx)?.get_non_agg( + concat!( + "HAVING clause does not support inline aggregations.", + " Select the expression `AS some_col_name` ", + "and then use `HAVING` on `some_col_name`." + ), + expr_ast, + )?; + ctx.instrs.push(Instruction::Filter { + index: table_reg_index, + expr, + }) + } } SetExpr::Values(exprs) => { if exprs.0.len() == 1 && exprs.0[0].len() == 1 { - let expr: Expr = exprs.0[0][0].clone().try_into()?; - instrs.push(Instruction::Expr { + let expr_ast = exprs.0[0][0].clone(); + let expr = codegen_expr(expr_ast.clone(), &mut ctx)?.get_non_agg( + "Aggregate expressions are not supported in values", + expr_ast, + )?; + ctx.instrs.push(Instruction::Expr { index: table_reg_index, expr, }); @@ -376,9 +553,13 @@ pub fn codegen_ast(ast: &Statement) -> Result { }; for order_by in query.order_by.clone() { - instrs.push(Instruction::Order { + let order_by_expr = codegen_expr(order_by.expr.clone(), &mut ctx)?.get_non_agg( + "Aggregate expressions are not supported in ORDER BY", + order_by.expr, + )?; + ctx.instrs.push(Instruction::Order { index: table_reg_index, - expr: order_by.expr.try_into()?, + expr: order_by_expr, ascending: order_by.asc.unwrap_or(true), }); // TODO: support NULLS FIRST/NULLS LAST @@ -387,7 +568,7 @@ pub fn codegen_ast(ast: &Statement) -> Result { if let Some(limit) = query.limit.clone() { if let ast::Expr::Value(val) = limit.clone() { if let Value::Int64(limit) = val.clone().try_into()? { - instrs.push(Instruction::Limit { + ctx.instrs.push(Instruction::Limit { index: table_reg_index, limit: limit as u64, }); @@ -407,7 +588,7 @@ pub fn codegen_ast(ast: &Statement) -> Result { } } - instrs.push(Instruction::Return { + ctx.instrs.push(Instruction::Return { index: table_reg_index, }); @@ -417,7 +598,7 @@ pub fn codegen_ast(ast: &Statement) -> Result { schema_name, if_not_exists, } => { - instrs.push(Instruction::NewSchema { + ctx.instrs.push(Instruction::NewSchema { schema_name: schema_name.0.clone().try_into()?, exists_ok: *if_not_exists, }); @@ -426,7 +607,298 @@ pub fn codegen_ast(ast: &Statement) -> Result { _ => Err(CodegenError::UnsupportedStatement(ast.to_string())), }?; - Ok(IntermediateCode { instrs }) + Ok(IntermediateCode { instrs: ctx.instrs }) +} + +#[derive(Debug, Clone, PartialEq)] +struct IntermediateExprAgg { + pub pre_agg: Vec<(Expr, BoundedString)>, + pub agg: Vec<(AggregateFunction, BoundedString, BoundedString)>, + pub post_agg: Vec, + last_alias: Option, + last_expr: (Expr, BoundedString), +} + +#[derive(Debug, Clone)] +enum IntermediateExpr { + NonAgg(Expr), + Agg(IntermediateExprAgg), +} + +impl IntermediateExpr { + pub fn new_non_agg(expr: Expr) -> Self { + Self::NonAgg(expr) + } + + /// The last expression that was generated. + pub fn last_expr(&self) -> &Expr { + match self { + Self::NonAgg(e) => e, + Self::Agg(agg) => &agg.last_expr.0, + } + } + + pub fn combine(self, new: IntermediateExpr) -> Self { + match self { + Self::NonAgg(_) => new, + Self::Agg(mut sel) => match new { + Self::NonAgg(new) => { + // TODO: last_expr may need updating here? + if sel.post_agg.len() <= 1 { + sel.post_agg = vec![new]; + } else { + *sel.post_agg.last_mut().unwrap() = new; + } + Self::Agg(sel) + } + Self::Agg(new) => { + // TODO: last_expr may need updating here? + sel.pre_agg.extend_from_slice(&new.pre_agg); + sel.agg.extend_from_slice(&new.agg); + sel.post_agg.extend_from_slice(&new.post_agg); + Self::Agg(sel) + } + }, + } + } + + pub fn get_non_agg( + self, + err_reason: &'static str, + expr_ast: ast::Expr, + ) -> Result { + match self { + IntermediateExpr::Agg(_) => Err(ExprError::Expr { + reason: err_reason, + expr: expr_ast, + }), + IntermediateExpr::NonAgg(e) => Ok(e), + } + } +} + +fn codegen_selectitem( + projection: &SelectItem, + ctx: &mut CodegenContext, +) -> Result { + match projection { + SelectItem::UnnamedExpr(ref expr) => codegen_expr(expr.clone(), ctx), + SelectItem::ExprWithAlias { ref expr, .. } => codegen_expr(expr.clone(), ctx), + SelectItem::QualifiedWildcard(_) => Ok(IntermediateExpr::new_non_agg(Expr::Wildcard)), + SelectItem::Wildcard => Ok(IntermediateExpr::new_non_agg(Expr::Wildcard)), + } +} + +fn codegen_expr( + expr_ast: ast::Expr, + ctx: &mut CodegenContext, +) -> Result { + match expr_ast { + ast::Expr::Identifier(i) => Ok(IntermediateExpr::new_non_agg(Expr::ColumnRef( + vec![i].try_into()?, + ))), + ast::Expr::CompoundIdentifier(i) => Ok(IntermediateExpr::new_non_agg(Expr::ColumnRef( + i.try_into()?, + ))), + ast::Expr::IsFalse(e) => { + let inner = codegen_expr(*e, ctx)?; + let new_expr = Expr::Unary { + op: UnOp::IsFalse, + operand: Box::new(inner.last_expr().clone()), + }; + Ok(inner.combine(IntermediateExpr::new_non_agg(new_expr))) + } + ast::Expr::IsTrue(e) => { + let inner = codegen_expr(*e, ctx)?; + let new_expr = Expr::Unary { + op: UnOp::IsTrue, + operand: Box::new(inner.last_expr().clone()), + }; + Ok(inner.combine(IntermediateExpr::new_non_agg(new_expr))) + } + ast::Expr::IsNull(e) => { + let inner = codegen_expr(*e, ctx)?; + let new_expr = Expr::Unary { + op: UnOp::IsNull, + operand: Box::new(inner.last_expr().clone()), + }; + Ok(inner.combine(IntermediateExpr::new_non_agg(new_expr))) + } + ast::Expr::IsNotNull(e) => { + let inner = codegen_expr(*e, ctx)?; + let new_expr = Expr::Unary { + op: UnOp::IsNotNull, + operand: Box::new(inner.last_expr().clone()), + }; + Ok(inner.combine(IntermediateExpr::new_non_agg(new_expr))) + } + ast::Expr::Between { + expr, + negated, + low, + high, + } => { + let expr_gen = codegen_expr(*expr, ctx)?; + let expr: Box = Box::new(expr_gen.last_expr().clone()); + + let left_gen = codegen_expr(*low, ctx)?; + let left = Box::new(left_gen.last_expr().clone()); + + let right_gen = codegen_expr(*high, ctx)?; + let right = Box::new(right_gen.last_expr().clone()); + + let between_gen = IntermediateExpr::new_non_agg(Expr::Binary { + left: Box::new(Expr::Binary { + left, + op: BinOp::LessThanOrEqual, + right: expr.clone(), + }), + op: BinOp::And, + right: Box::new(Expr::Binary { + left: expr, + op: BinOp::LessThanOrEqual, + right, + }), + }); + let between_last_expr = between_gen.last_expr().clone(); + + let between = expr_gen + .combine(left_gen) + .combine(right_gen) + .combine(between_gen); + + if negated { + Ok(between.combine(IntermediateExpr::new_non_agg(Expr::Unary { + op: UnOp::Not, + operand: Box::new(between_last_expr), + }))) + } else { + Ok(between) + } + } + ast::Expr::BinaryOp { left, op, right } => { + let left = codegen_expr(*left, ctx)?; + let left_operand = Box::new(left.last_expr().clone()); + let right = codegen_expr(*right, ctx)?; + let right_operand = Box::new(right.last_expr().clone()); + + let binary_expr = Expr::Binary { + left: left_operand, + op: op.try_into()?, + right: right_operand, + }; + + Ok(left + .combine(right) + .combine(IntermediateExpr::new_non_agg(binary_expr))) + } + ast::Expr::UnaryOp { op, expr } => { + let inner = codegen_expr(*expr, ctx)?; + let new_expr = Expr::Unary { + op: op.try_into()?, + operand: Box::new(inner.last_expr().clone()), + }; + Ok(inner.combine(IntermediateExpr::new_non_agg(new_expr))) + } + ast::Expr::Value(v) => Ok(IntermediateExpr::new_non_agg(Expr::Value(v.try_into()?))), + ast::Expr::Function(ref f) => { + let fn_name = f.name.to_string(); + let args = f + .args + .iter() + .map(|arg| { + let ie = codegen_fn_arg(&expr_ast, arg, ctx)?; + let last_expr = ie.last_expr().clone(); + Ok::<(IntermediateExpr, Expr), ExprError>((ie, last_expr)) + }) + .collect::, _>>()?; + if is_fn_name_aggregate(&fn_name.to_lowercase()) { + if args.len() > 1 { + Err(ExprError::Expr { + reason: "Aggregates with more than one arguments are not supported yet.", + expr: expr_ast, + }) + } else { + let args = args + .into_iter() + .map(|a| match a { + (IntermediateExpr::Agg(_), _) => Err(ExprError::Expr { + reason: "Aggregates within aggregates are not supported yet", + expr: expr_ast.clone(), + }), + (IntermediateExpr::NonAgg(e), _) => Ok((e, ctx.get_new_temp_col())), + }) + .collect::, _>>()?; + let agg_result_col = ctx.get_new_temp_col(); + let agg_col_res = Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: agg_result_col, + }); + let agg = vec![( + AggregateFunction::from_name(&fn_name.to_lowercase())?, + args[0].1, + agg_result_col, + )]; + Ok(IntermediateExpr::Agg(IntermediateExprAgg { + pre_agg: args, + agg, + post_agg: vec![agg_col_res.clone()], + last_alias: Some(agg_result_col), + last_expr: (agg_col_res, agg_result_col), + })) + } + } else { + if args.is_empty() { + Ok(IntermediateExpr::new_non_agg(Expr::Function { + name: fn_name.as_str().into(), + args: args.iter().map(|ie| ie.1.clone()).collect(), + })) + } else { + let start = args[0].0.clone(); + let combined = args + .iter() + .skip(1) + .fold(start, |acc, ie| acc.combine(ie.0.clone())); + Ok( + combined.combine(IntermediateExpr::new_non_agg(Expr::Function { + name: fn_name.as_str().into(), + args: args.iter().map(|ie| ie.1.clone()).collect(), + })), + ) + } + } + } + _ => Err(ExprError::Expr { + reason: "Unsupported expression", + expr: expr_ast, + }), + } +} + +fn is_fn_name_aggregate(fn_name: &str) -> bool { + AGGREGATE_FUNCTIONS.contains(fn_name) +} + +fn codegen_fn_arg( + expr_ast: &ast::Expr, + arg: &FunctionArg, + ctx: &mut CodegenContext, +) -> Result { + match arg { + ast::FunctionArg::Unnamed(arg_expr) => match arg_expr { + ast::FunctionArgExpr::Expr(e) => Ok(codegen_expr(e.clone(), ctx)?), + ast::FunctionArgExpr::Wildcard => Ok(IntermediateExpr::new_non_agg(Expr::Wildcard)), + ast::FunctionArgExpr::QualifiedWildcard(_) => Err(ExprError::Expr { + reason: "Qualified wildcards are not supported yet", + expr: expr_ast.clone(), + }), + }, + ast::FunctionArg::Named { .. } => Err(ExprError::Expr { + reason: "Named function arguments are not supported", + expr: expr_ast.clone(), + }), + } } /// Error while generating an intermediate code from the AST. @@ -474,16 +946,16 @@ impl From for CodegenError { impl Error for CodegenError {} #[cfg(test)] -mod tests { +mod codegen_tests { use sqlparser::ast::{ColumnOption, ColumnOptionDef, DataType}; use pretty_assertions::assert_eq; use crate::{ codegen::codegen_ast, - expr::{BinOp, Expr}, - ic::Instruction, + expr::{agg::AggregateFunction, BinOp, Expr}, identifier::{ColumnRef, SchemaRef, TableRef}, + ir::Instruction, parser::parse, value::Value, vm::RegisterIndex, @@ -1113,11 +1585,10 @@ mod tests { FROM table1 WHERE col1 = 1 GROUP BY col2 - HAVING MAX(col3) > 10 + HAVING max_col3 > 10 ", |instrs| { assert_eq!( - instrs, &[ Instruction::Source { index: RegisterIndex::default(), @@ -1138,35 +1609,61 @@ mod tests { right: Box::new(Expr::Value(Value::Int64(1))) }, }, + Instruction::Empty { + index: RegisterIndex::default().next_index() + }, + Instruction::Project { + input: RegisterIndex::default(), + output: RegisterIndex::default().next_index(), + expr: Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col3".into(), + }), + alias: Some("__otter_temp_col_1".into()) + }, + Instruction::Empty { + index: RegisterIndex::default().next_index().next_index() + }, Instruction::GroupBy { - index: RegisterIndex::default(), + input: RegisterIndex::default().next_index(), + output: RegisterIndex::default().next_index().next_index(), expr: Expr::ColumnRef(ColumnRef { schema_name: None, table_name: None, col_name: "col2".into(), }) }, - Instruction::Filter { - index: RegisterIndex::default(), - expr: Expr::Binary { - left: Box::new(Expr::Function { - name: "MAX".into(), - args: vec![Expr::ColumnRef(ColumnRef { - schema_name: None, - table_name: None, - col_name: "col3".into(), - })] - }), - op: BinOp::GreaterThan, - right: Box::new(Expr::Value(Value::Int64(10))) - }, + Instruction::Empty { + index: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + }, + Instruction::Aggregate { + input: RegisterIndex::default().next_index().next_index(), + output: RegisterIndex::default() + .next_index() + .next_index() + .next_index(), + func: AggregateFunction::Max, + col_name: "__otter_temp_col_1".into(), + alias: Some("__otter_temp_col_2".into()), }, Instruction::Empty { - index: RegisterIndex::default().next_index() + index: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index() }, Instruction::Project { input: RegisterIndex::default(), - output: RegisterIndex::default().next_index(), + output: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index(), expr: Expr::ColumnRef(ColumnRef { schema_name: None, table_name: None, @@ -1175,24 +1672,872 @@ mod tests { alias: None }, Instruction::Project { - input: RegisterIndex::default(), - output: RegisterIndex::default().next_index(), - expr: Expr::Function { - name: "MAX".into(), - args: vec![Expr::ColumnRef(ColumnRef { + input: RegisterIndex::default() + .next_index() + .next_index() + .next_index(), + output: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index(), + expr: Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "__otter_temp_col_2".into(), + }), + alias: Some("max_col3".into()) + }, + Instruction::Filter { + index: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index(), + expr: Expr::Binary { + left: Box::new(Expr::ColumnRef(ColumnRef { schema_name: None, table_name: None, - col_name: "col3".into(), - })] + col_name: "max_col3".into(), + })), + op: BinOp::GreaterThan, + right: Box::new(Expr::Value(Value::Int64(10))) }, - alias: Some("max_col3".into()) }, Instruction::Return { - index: RegisterIndex::default().next_index(), + index: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index(), } - ] + ], + instrs, ) }, ); + + check_single_statement( + "SELECT col2, MAX(col3) + 1 AS max_col3 + FROM table1 + GROUP BY col2 + ", + |instrs| { + assert_eq!( + &[ + Instruction::Source { + index: RegisterIndex::default(), + name: TableRef { + schema_name: None, + table_name: "table1".into() + } + }, + Instruction::Empty { + index: RegisterIndex::default().next_index() + }, + Instruction::Project { + input: RegisterIndex::default(), + output: RegisterIndex::default().next_index(), + expr: Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col3".into(), + }), + alias: Some("__otter_temp_col_1".into()) + }, + Instruction::Empty { + index: RegisterIndex::default().next_index().next_index() + }, + Instruction::GroupBy { + input: RegisterIndex::default().next_index(), + output: RegisterIndex::default().next_index().next_index(), + expr: Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col2".into(), + }) + }, + Instruction::Empty { + index: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + }, + Instruction::Aggregate { + input: RegisterIndex::default().next_index().next_index(), + output: RegisterIndex::default() + .next_index() + .next_index() + .next_index(), + func: AggregateFunction::Max, + col_name: "__otter_temp_col_1".into(), + alias: Some("__otter_temp_col_2".into()), + }, + Instruction::Empty { + index: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index() + }, + Instruction::Project { + input: RegisterIndex::default(), + output: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index(), + expr: Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col2".into(), + }), + alias: None + }, + Instruction::Project { + input: RegisterIndex::default() + .next_index() + .next_index() + .next_index(), + output: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index(), + expr: Expr::Binary { + left: Box::new(Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "__otter_temp_col_2".into(), + })), + op: BinOp::Plus, + right: Box::new(Expr::Value(Value::Int64(1))), + }, + alias: Some("max_col3".into()), + }, + Instruction::Return { + index: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index(), + } + ], + instrs, + ) + }, + ); + + check_single_statement( + "SELECT col2, col3, SUM(col4 * col4) AS sos + FROM table1 + WHERE col1 = 1 + GROUP BY col2, col3 + ", + |instrs| { + assert_eq!( + &[ + Instruction::Source { + index: RegisterIndex::default(), + name: TableRef { + schema_name: None, + table_name: "table1".into() + } + }, + Instruction::Filter { + index: RegisterIndex::default(), + expr: Expr::Binary { + left: Box::new(Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col1".into(), + },)), + op: BinOp::Equal, + right: Box::new(Expr::Value(Value::Int64(1))) + }, + }, + Instruction::Empty { + index: RegisterIndex::default().next_index() + }, + Instruction::Project { + input: RegisterIndex::default(), + output: RegisterIndex::default().next_index(), + expr: Expr::Binary { + left: Box::new(Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col4".into(), + })), + op: BinOp::Multiply, + right: Box::new(Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col4".into(), + })) + }, + alias: Some("__otter_temp_col_1".into()) + }, + Instruction::Empty { + index: RegisterIndex::default().next_index().next_index() + }, + Instruction::GroupBy { + input: RegisterIndex::default().next_index(), + output: RegisterIndex::default().next_index().next_index(), + expr: Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col2".into(), + }) + }, + Instruction::Empty { + index: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + }, + Instruction::GroupBy { + input: RegisterIndex::default().next_index().next_index(), + output: RegisterIndex::default() + .next_index() + .next_index() + .next_index(), + expr: Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col3".into(), + }) + }, + Instruction::Empty { + index: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index() + }, + Instruction::Aggregate { + input: RegisterIndex::default() + .next_index() + .next_index() + .next_index(), + output: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index(), + func: AggregateFunction::Sum, + col_name: "__otter_temp_col_1".into(), + alias: Some("__otter_temp_col_2".into()), + }, + Instruction::Empty { + index: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index() + .next_index() + }, + Instruction::Project { + input: RegisterIndex::default(), + output: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index() + .next_index(), + expr: Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col2".into(), + }), + alias: None + }, + Instruction::Project { + input: RegisterIndex::default(), + output: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index() + .next_index(), + expr: Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col3".into(), + }), + alias: None + }, + Instruction::Project { + input: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index(), + output: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index() + .next_index(), + expr: Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "__otter_temp_col_2".into(), + }), + alias: Some("sos".into()) + }, + Instruction::Return { + index: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index() + .next_index(), + } + ], + instrs + ) + }, + ); + } +} + +#[cfg(test)] +mod expr_codegen_tests { + use sqlparser::{ast, dialect::GenericDialect, parser::Parser, tokenizer::Tokenizer}; + + use pretty_assertions::{assert_eq, assert_ne}; + + use crate::{ + codegen::{codegen_expr, CodegenContext, IntermediateExpr, IntermediateExprAgg}, + expr::{agg::AggregateFunction, BinOp, Expr, ExprError, UnOp}, + identifier::ColumnRef, + value::Value, + }; + + #[test] + fn conversion_from_ast() { + fn parse_expr(s: &str) -> ast::Expr { + let dialect = GenericDialect {}; + let mut tokenizer = Tokenizer::new(&dialect, s); + let tokens = tokenizer.tokenize().unwrap(); + let mut parser = Parser::new(tokens, &dialect); + parser.parse_expr().unwrap() + } + + fn codegen_expr_wrapper_no_agg(expr_ast: ast::Expr) -> Result { + let mut ctx = CodegenContext::new(); + match codegen_expr(expr_ast, &mut ctx)? { + IntermediateExpr::Agg(_) => panic!("Expected unaggregated expression"), + IntermediateExpr::NonAgg(expr) => Ok(expr), + } + } + + fn codegen_expr_wrapper_agg(expr_ast: ast::Expr) -> Result { + let mut ctx = CodegenContext::new(); + match codegen_expr(expr_ast, &mut ctx)? { + IntermediateExpr::Agg(agg) => Ok(agg), + IntermediateExpr::NonAgg(_) => panic!("Expected aggregated expression"), + } + } + + assert_eq!( + codegen_expr_wrapper_no_agg(parse_expr("abc")), + Ok(Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "abc".into() + })) + ); + + assert_ne!( + codegen_expr_wrapper_no_agg(parse_expr("abc")), + Ok(Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "cab".into() + })) + ); + + assert_eq!( + codegen_expr_wrapper_no_agg(parse_expr("table1.col1")), + Ok(Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: Some("table1".into()), + col_name: "col1".into() + })) + ); + + assert_eq!( + codegen_expr_wrapper_no_agg(parse_expr("schema1.table1.col1")), + Ok(Expr::ColumnRef(ColumnRef { + schema_name: Some("schema1".into()), + table_name: Some("table1".into()), + col_name: "col1".into() + })) + ); + + assert_eq!( + codegen_expr_wrapper_no_agg(parse_expr("5 IS NULL")), + Ok(Expr::Unary { + op: UnOp::IsNull, + operand: Box::new(Expr::Value(Value::Int64(5))) + }) + ); + + assert_eq!( + codegen_expr_wrapper_no_agg(parse_expr("1 IS TRUE")), + Ok(Expr::Unary { + op: UnOp::IsTrue, + operand: Box::new(Expr::Value(Value::Int64(1))) + }) + ); + + assert_eq!( + codegen_expr_wrapper_no_agg(parse_expr("4 BETWEEN 3 AND 5")), + Ok(Expr::Binary { + left: Box::new(Expr::Binary { + left: Box::new(Expr::Value(Value::Int64(3))), + op: BinOp::LessThanOrEqual, + right: Box::new(Expr::Value(Value::Int64(4))) + }), + op: BinOp::And, + right: Box::new(Expr::Binary { + left: Box::new(Expr::Value(Value::Int64(4))), + op: BinOp::LessThanOrEqual, + right: Box::new(Expr::Value(Value::Int64(5))) + }) + }) + ); + + assert_eq!( + codegen_expr_wrapper_no_agg(parse_expr("4 NOT BETWEEN 3 AND 5")), + Ok(Expr::Unary { + op: UnOp::Not, + operand: Box::new(Expr::Binary { + left: Box::new(Expr::Binary { + left: Box::new(Expr::Value(Value::Int64(3))), + op: BinOp::LessThanOrEqual, + right: Box::new(Expr::Value(Value::Int64(4))) + }), + op: BinOp::And, + right: Box::new(Expr::Binary { + left: Box::new(Expr::Value(Value::Int64(4))), + op: BinOp::LessThanOrEqual, + right: Box::new(Expr::Value(Value::Int64(5))) + }) + }) + }) + ); + + assert_eq!( + codegen_expr_wrapper_agg(parse_expr("MAX(col1)")), + Ok(IntermediateExprAgg { + pre_agg: vec![( + Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col1".into() + }), + "__otter_temp_col_1".into() + )], + agg: vec![( + AggregateFunction::Max, + "__otter_temp_col_1".into(), + "__otter_temp_col_2".into() + )], + post_agg: vec![Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "__otter_temp_col_2".into() + })], + last_alias: Some("__otter_temp_col_2".into()), + last_expr: ( + Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "__otter_temp_col_2".into() + }), + "__otter_temp_col_2".into() + ) + }) + ); + + assert_eq!( + codegen_expr_wrapper_no_agg(parse_expr("some_func(col1, 1, 'abc')")), + Ok(Expr::Function { + name: "some_func".into(), + args: vec![ + Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col1".into() + }), + Expr::Value(Value::Int64(1)), + Expr::Value(Value::String("abc".to_owned())) + ] + }) + ); + + assert_eq!( + codegen_expr_wrapper_agg(parse_expr("COUNT(*)")), + Ok(IntermediateExprAgg { + pre_agg: vec![(Expr::Wildcard, "__otter_temp_col_1".into())], + agg: vec![( + AggregateFunction::Count, + "__otter_temp_col_1".into(), + "__otter_temp_col_2".into() + )], + post_agg: vec![Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "__otter_temp_col_2".into() + })], + last_alias: Some("__otter_temp_col_2".into()), + last_expr: ( + Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "__otter_temp_col_2".into() + }), + "__otter_temp_col_2".into() + ) + }) + ); + } +} + +#[cfg(test)] +mod expr_eval_tests { + use sqlparser::{ + ast::{ColumnOption, ColumnOptionDef, DataType}, + dialect::GenericDialect, + parser::Parser, + tokenizer::Tokenizer, + }; + + use crate::{ + column::Column, + expr::{eval::ExprExecError, BinOp, Expr, UnOp}, + table::{Row, Table}, + value::{Value, ValueBinaryOpError, ValueUnaryOpError}, + }; + + use super::{codegen_expr, CodegenContext, IntermediateExpr}; + + fn str_to_expr(s: &str) -> Expr { + let dialect = GenericDialect {}; + let mut tokenizer = Tokenizer::new(&dialect, s); + let tokens = tokenizer.tokenize().unwrap(); + let mut parser = Parser::new(tokens, &dialect); + let mut ctx = CodegenContext::new(); + match codegen_expr(parser.parse_expr().unwrap(), &mut ctx).unwrap() { + IntermediateExpr::NonAgg(expr) => expr, + IntermediateExpr::Agg(_) => panic!("Did not expect aggregate expression here"), + } + } + + fn exec_expr_no_context(expr: Expr) -> Result { + let mut table = Table::new_temp(0); + table.new_row(vec![]); + Expr::execute(&expr, &table, table.all_data()[0].to_shared()) + } + + fn exec_str_no_context(s: &str) -> Result { + let expr = str_to_expr(s); + exec_expr_no_context(expr) + } + + fn exec_str_with_context(s: &str, table: &Table, row: &Row) -> Result { + let expr = str_to_expr(s); + Expr::execute(&expr, table, row.to_shared()) + } + + #[test] + fn exec_value() { + assert_eq!(exec_str_no_context("NULL"), Ok(Value::Null)); + + assert_eq!(exec_str_no_context("true"), Ok(Value::Bool(true))); + + assert_eq!(exec_str_no_context("1"), Ok(Value::Int64(1))); + + assert_eq!(exec_str_no_context("1.1"), Ok(Value::Float64(1.1.into()))); + + assert_eq!(exec_str_no_context(".1"), Ok(Value::Float64(0.1.into()))); + + assert_eq!( + exec_str_no_context("'str'"), + Ok(Value::String("str".to_owned())) + ); + } + + #[test] + fn exec_logical() { + assert_eq!(exec_str_no_context("true and true"), Ok(Value::Bool(true))); + assert_eq!( + exec_str_no_context("true and false"), + Ok(Value::Bool(false)) + ); + assert_eq!( + exec_str_no_context("false and true"), + Ok(Value::Bool(false)) + ); + assert_eq!( + exec_str_no_context("false and false"), + Ok(Value::Bool(false)) + ); + assert_eq!( + exec_str_no_context("false and 10"), + Err(ValueBinaryOpError { + operator: BinOp::And, + values: (Value::Bool(false), Value::Int64(10)) + } + .into()) + ); + assert_eq!( + exec_str_no_context("10 and false"), + Err(ValueBinaryOpError { + operator: BinOp::And, + values: (Value::Int64(10), Value::Bool(false)) + } + .into()) + ); + + assert_eq!(exec_str_no_context("true or true"), Ok(Value::Bool(true))); + assert_eq!(exec_str_no_context("true or false"), Ok(Value::Bool(true))); + assert_eq!(exec_str_no_context("false or true"), Ok(Value::Bool(true))); + assert_eq!( + exec_str_no_context("false or false"), + Ok(Value::Bool(false)) + ); + assert_eq!( + exec_str_no_context("true or 10"), + Err(ValueBinaryOpError { + operator: BinOp::Or, + values: (Value::Bool(true), Value::Int64(10)) + } + .into()) + ); + assert_eq!( + exec_str_no_context("10 or true"), + Err(ValueBinaryOpError { + operator: BinOp::Or, + values: (Value::Int64(10), Value::Bool(true)) + } + .into()) + ); + } + + #[test] + fn exec_arithmetic() { + assert_eq!(exec_str_no_context("1 + 1"), Ok(Value::Int64(2))); + assert_eq!( + exec_str_no_context("1.1 + 1.1"), + Ok(Value::Float64(2.2.into())) + ); + + // this applies to all binary ops + assert_eq!( + exec_str_no_context("1 + 1.1"), + Err(ValueBinaryOpError { + operator: BinOp::Plus, + values: (Value::Int64(1), Value::Float64(1.1.into())) + } + .into()) + ); + + assert_eq!(exec_str_no_context("4 - 2"), Ok(Value::Int64(2))); + assert_eq!(exec_str_no_context("4 - 6"), Ok(Value::Int64(-2))); + assert_eq!( + exec_str_no_context("4.5 - 2.2"), + Ok(Value::Float64(2.3.into())) + ); + + assert_eq!(exec_str_no_context("4 * 2"), Ok(Value::Int64(8))); + assert_eq!( + exec_str_no_context("0.5 * 2.2"), + Ok(Value::Float64(1.1.into())) + ); + + assert_eq!(exec_str_no_context("4 / 2"), Ok(Value::Int64(2))); + assert_eq!(exec_str_no_context("4 / 3"), Ok(Value::Int64(1))); + assert_eq!( + exec_str_no_context("4.0 / 2.0"), + Ok(Value::Float64(2.0.into())) + ); + assert_eq!( + exec_str_no_context("5.1 / 2.5"), + Ok(Value::Float64(2.04.into())) + ); + + assert_eq!(exec_str_no_context("5 % 2"), Ok(Value::Int64(1))); + assert_eq!( + exec_str_no_context("5.5 % 2.5"), + Ok(Value::Float64(0.5.into())) + ); + } + + #[test] + fn exec_comparison() { + assert_eq!(exec_str_no_context("1 = 1"), Ok(Value::Bool(true))); + assert_eq!(exec_str_no_context("1 = 2"), Ok(Value::Bool(false))); + assert_eq!(exec_str_no_context("1 != 2"), Ok(Value::Bool(true))); + assert_eq!(exec_str_no_context("1.1 = 1.1"), Ok(Value::Bool(true))); + assert_eq!(exec_str_no_context("1.2 = 1.22"), Ok(Value::Bool(false))); + assert_eq!(exec_str_no_context("1.2 != 1.22"), Ok(Value::Bool(true))); + + assert_eq!(exec_str_no_context("1 < 2"), Ok(Value::Bool(true))); + assert_eq!(exec_str_no_context("1 < 1"), Ok(Value::Bool(false))); + assert_eq!(exec_str_no_context("1 <= 2"), Ok(Value::Bool(true))); + assert_eq!(exec_str_no_context("1 <= 1"), Ok(Value::Bool(true))); + assert_eq!(exec_str_no_context("3 > 2"), Ok(Value::Bool(true))); + assert_eq!(exec_str_no_context("3 > 3"), Ok(Value::Bool(false))); + assert_eq!(exec_str_no_context("3 >= 2"), Ok(Value::Bool(true))); + assert_eq!(exec_str_no_context("3 >= 3"), Ok(Value::Bool(true))); + } + + #[test] + fn exec_pattern_match() { + assert_eq!( + exec_str_no_context("'my name is yoshikage kira' LIKE 'kira'"), + Ok(Value::Bool(true)) + ); + assert_eq!( + exec_str_no_context("'my name is yoshikage kira' LIKE 'KIRA'"), + Ok(Value::Bool(false)) + ); + assert_eq!( + exec_str_no_context("'my name is yoshikage kira' LIKE 'kira yoshikage'"), + Ok(Value::Bool(false)) + ); + + assert_eq!( + exec_str_no_context("'my name is Yoshikage Kira' ILIKE 'kira'"), + Ok(Value::Bool(true)) + ); + assert_eq!( + exec_str_no_context("'my name is Yoshikage Kira' ILIKE 'KIRA'"), + Ok(Value::Bool(true)) + ); + assert_eq!( + exec_str_no_context("'my name is Yoshikage Kira' ILIKE 'KIRAA'"), + Ok(Value::Bool(false)) + ); + } + + #[test] + fn exec_unary() { + assert_eq!(exec_str_no_context("+1"), Ok(Value::Int64(1))); + assert_eq!(exec_str_no_context("+ -1"), Ok(Value::Int64(-1))); + assert_eq!(exec_str_no_context("-1"), Ok(Value::Int64(-1))); + assert_eq!(exec_str_no_context("- -1"), Ok(Value::Int64(1))); + assert_eq!(exec_str_no_context("not true"), Ok(Value::Bool(false))); + assert_eq!(exec_str_no_context("not false"), Ok(Value::Bool(true))); + + assert_eq!(exec_str_no_context("true is true"), Ok(Value::Bool(true))); + assert_eq!(exec_str_no_context("false is false"), Ok(Value::Bool(true))); + assert_eq!(exec_str_no_context("false is true"), Ok(Value::Bool(false))); + assert_eq!(exec_str_no_context("true is false"), Ok(Value::Bool(false))); + assert_eq!( + exec_str_no_context("1 is true"), + Err(ValueUnaryOpError { + operator: UnOp::IsTrue, + value: Value::Int64(1) + } + .into()) + ); + + assert_eq!(exec_str_no_context("NULL is NULL"), Ok(Value::Bool(true))); + assert_eq!( + exec_str_no_context("NULL is not NULL"), + Ok(Value::Bool(false)) + ); + assert_eq!(exec_str_no_context("1 is NULL"), Ok(Value::Bool(false))); + assert_eq!(exec_str_no_context("1 is not NULL"), Ok(Value::Bool(true))); + assert_eq!(exec_str_no_context("0 is not NULL"), Ok(Value::Bool(true))); + assert_eq!(exec_str_no_context("'' is not NULL"), Ok(Value::Bool(true))); + } + + #[test] + fn exec_wildcard() { + assert_eq!( + exec_expr_no_context(Expr::Wildcard), + Err(ExprExecError::CannotExecute(Expr::Wildcard)) + ); + } + + #[test] + fn exec_column_ref() { + let mut table = Table::new( + "table1".into(), + vec![ + Column::new( + "col1".into(), + DataType::Int(None), + vec![ColumnOptionDef { + name: None, + option: ColumnOption::Unique { is_primary: true }, + }], + false, + ), + Column::new( + "col2".into(), + DataType::Int(None), + vec![ColumnOptionDef { + name: None, + option: ColumnOption::Unique { is_primary: false }, + }], + false, + ), + Column::new("col3".into(), DataType::String, vec![], false), + ], + ); + table.new_row(vec![ + Value::Int64(4), + Value::Int64(10), + Value::String("brr".to_owned()), + ]); + + assert_eq!( + table.all_data(), + vec![Row::new(vec![ + Value::Int64(4), + Value::Int64(10), + Value::String("brr".to_owned()) + ])] + ); + + assert_eq!( + exec_str_with_context("col1", &table, &table.all_data()[0]), + Ok(Value::Int64(4)) + ); + + assert_eq!( + exec_str_with_context("col3", &table, &table.all_data()[0]), + Ok(Value::String("brr".to_owned())) + ); + + assert_eq!( + exec_str_with_context("col1 = 4", &table, &table.all_data()[0]), + Ok(Value::Bool(true)) + ); + + assert_eq!( + exec_str_with_context("col1 + 1", &table, &table.all_data()[0]), + Ok(Value::Int64(5)) + ); + + assert_eq!( + exec_str_with_context("col1 + col2", &table, &table.all_data()[0]), + Ok(Value::Int64(14)) + ); + + assert_eq!( + exec_str_with_context( + "col1 + col2 = 10 or col1 * col2 = 40", + &table, + &table.all_data()[0] + ), + Ok(Value::Bool(true)) + ); } } diff --git a/src/expr/agg.rs b/src/expr/agg.rs new file mode 100644 index 0000000..0f3d9c7 --- /dev/null +++ b/src/expr/agg.rs @@ -0,0 +1,36 @@ +use super::ExprError; +use std::fmt::Display; + +#[derive(Debug, Clone, PartialEq, Eq)] +/// Functions that reduce an entire column to a single value. +pub enum AggregateFunction { + Count, + Max, + Sum, +} + +impl AggregateFunction { + /// Get an aggregation function by name. + pub fn from_name(name: &str) -> Result { + match name.to_lowercase().as_str() { + "count" => Ok(Self::Count), + "max" => Ok(Self::Max), + "sum" => Ok(Self::Sum), + _ => Err(ExprError::UnknownAggregateFunction(name.to_owned())), + } + } +} + +impl Display for AggregateFunction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Agg({})", + match self { + AggregateFunction::Count => "count", + AggregateFunction::Max => "max", + AggregateFunction::Sum => "sum", + } + ) + } +} diff --git a/src/expr/eval.rs b/src/expr/eval.rs index 7739186..1ae01d3 100644 --- a/src/expr/eval.rs +++ b/src/expr/eval.rs @@ -153,335 +153,3 @@ impl Display for ExprExecError { } impl Error for ExprExecError {} - -#[cfg(test)] -mod test { - use sqlparser::{ - ast::{ColumnOption, ColumnOptionDef, DataType}, - dialect::GenericDialect, - parser::Parser, - tokenizer::Tokenizer, - }; - - use crate::{ - column::Column, - expr::{BinOp, Expr, UnOp}, - table::{Row, Table}, - value::{Value, ValueBinaryOpError, ValueUnaryOpError}, - }; - - use super::ExprExecError; - - fn str_to_expr(s: &str) -> Expr { - let dialect = GenericDialect {}; - let mut tokenizer = Tokenizer::new(&dialect, s); - let tokens = tokenizer.tokenize().unwrap(); - let mut parser = Parser::new(tokens, &dialect); - parser.parse_expr().unwrap().try_into().unwrap() - } - - fn exec_expr_no_context(expr: Expr) -> Result { - let mut table = Table::new_temp(0); - table.new_row(vec![]); - Expr::execute(&expr, &table, table.all_data()[0].to_shared()) - } - - fn exec_str_no_context(s: &str) -> Result { - let expr = str_to_expr(s); - exec_expr_no_context(expr) - } - - fn exec_str_with_context(s: &str, table: &Table, row: &Row) -> Result { - let expr = str_to_expr(s); - Expr::execute(&expr, table, row.to_shared()) - } - - #[test] - fn exec_value() { - assert_eq!(exec_str_no_context("NULL"), Ok(Value::Null)); - - assert_eq!(exec_str_no_context("true"), Ok(Value::Bool(true))); - - assert_eq!(exec_str_no_context("1"), Ok(Value::Int64(1))); - - assert_eq!(exec_str_no_context("1.1"), Ok(Value::Float64(1.1.into()))); - - assert_eq!(exec_str_no_context(".1"), Ok(Value::Float64(0.1.into()))); - - assert_eq!( - exec_str_no_context("'str'"), - Ok(Value::String("str".to_owned())) - ); - } - - #[test] - fn exec_logical() { - assert_eq!(exec_str_no_context("true and true"), Ok(Value::Bool(true))); - assert_eq!( - exec_str_no_context("true and false"), - Ok(Value::Bool(false)) - ); - assert_eq!( - exec_str_no_context("false and true"), - Ok(Value::Bool(false)) - ); - assert_eq!( - exec_str_no_context("false and false"), - Ok(Value::Bool(false)) - ); - assert_eq!( - exec_str_no_context("false and 10"), - Err(ValueBinaryOpError { - operator: BinOp::And, - values: (Value::Bool(false), Value::Int64(10)) - } - .into()) - ); - assert_eq!( - exec_str_no_context("10 and false"), - Err(ValueBinaryOpError { - operator: BinOp::And, - values: (Value::Int64(10), Value::Bool(false)) - } - .into()) - ); - - assert_eq!(exec_str_no_context("true or true"), Ok(Value::Bool(true))); - assert_eq!(exec_str_no_context("true or false"), Ok(Value::Bool(true))); - assert_eq!(exec_str_no_context("false or true"), Ok(Value::Bool(true))); - assert_eq!( - exec_str_no_context("false or false"), - Ok(Value::Bool(false)) - ); - assert_eq!( - exec_str_no_context("true or 10"), - Err(ValueBinaryOpError { - operator: BinOp::Or, - values: (Value::Bool(true), Value::Int64(10)) - } - .into()) - ); - assert_eq!( - exec_str_no_context("10 or true"), - Err(ValueBinaryOpError { - operator: BinOp::Or, - values: (Value::Int64(10), Value::Bool(true)) - } - .into()) - ); - } - - #[test] - fn exec_arithmetic() { - assert_eq!(exec_str_no_context("1 + 1"), Ok(Value::Int64(2))); - assert_eq!( - exec_str_no_context("1.1 + 1.1"), - Ok(Value::Float64(2.2.into())) - ); - - // this applies to all binary ops - assert_eq!( - exec_str_no_context("1 + 1.1"), - Err(ValueBinaryOpError { - operator: BinOp::Plus, - values: (Value::Int64(1), Value::Float64(1.1.into())) - } - .into()) - ); - - assert_eq!(exec_str_no_context("4 - 2"), Ok(Value::Int64(2))); - assert_eq!(exec_str_no_context("4 - 6"), Ok(Value::Int64(-2))); - assert_eq!( - exec_str_no_context("4.5 - 2.2"), - Ok(Value::Float64(2.3.into())) - ); - - assert_eq!(exec_str_no_context("4 * 2"), Ok(Value::Int64(8))); - assert_eq!( - exec_str_no_context("0.5 * 2.2"), - Ok(Value::Float64(1.1.into())) - ); - - assert_eq!(exec_str_no_context("4 / 2"), Ok(Value::Int64(2))); - assert_eq!(exec_str_no_context("4 / 3"), Ok(Value::Int64(1))); - assert_eq!( - exec_str_no_context("4.0 / 2.0"), - Ok(Value::Float64(2.0.into())) - ); - assert_eq!( - exec_str_no_context("5.1 / 2.5"), - Ok(Value::Float64(2.04.into())) - ); - - assert_eq!(exec_str_no_context("5 % 2"), Ok(Value::Int64(1))); - assert_eq!( - exec_str_no_context("5.5 % 2.5"), - Ok(Value::Float64(0.5.into())) - ); - } - - #[test] - fn exec_comparison() { - assert_eq!(exec_str_no_context("1 = 1"), Ok(Value::Bool(true))); - assert_eq!(exec_str_no_context("1 = 2"), Ok(Value::Bool(false))); - assert_eq!(exec_str_no_context("1 != 2"), Ok(Value::Bool(true))); - assert_eq!(exec_str_no_context("1.1 = 1.1"), Ok(Value::Bool(true))); - assert_eq!(exec_str_no_context("1.2 = 1.22"), Ok(Value::Bool(false))); - assert_eq!(exec_str_no_context("1.2 != 1.22"), Ok(Value::Bool(true))); - - assert_eq!(exec_str_no_context("1 < 2"), Ok(Value::Bool(true))); - assert_eq!(exec_str_no_context("1 < 1"), Ok(Value::Bool(false))); - assert_eq!(exec_str_no_context("1 <= 2"), Ok(Value::Bool(true))); - assert_eq!(exec_str_no_context("1 <= 1"), Ok(Value::Bool(true))); - assert_eq!(exec_str_no_context("3 > 2"), Ok(Value::Bool(true))); - assert_eq!(exec_str_no_context("3 > 3"), Ok(Value::Bool(false))); - assert_eq!(exec_str_no_context("3 >= 2"), Ok(Value::Bool(true))); - assert_eq!(exec_str_no_context("3 >= 3"), Ok(Value::Bool(true))); - } - - #[test] - fn exec_pattern_match() { - assert_eq!( - exec_str_no_context("'my name is yoshikage kira' LIKE 'kira'"), - Ok(Value::Bool(true)) - ); - assert_eq!( - exec_str_no_context("'my name is yoshikage kira' LIKE 'KIRA'"), - Ok(Value::Bool(false)) - ); - assert_eq!( - exec_str_no_context("'my name is yoshikage kira' LIKE 'kira yoshikage'"), - Ok(Value::Bool(false)) - ); - - assert_eq!( - exec_str_no_context("'my name is Yoshikage Kira' ILIKE 'kira'"), - Ok(Value::Bool(true)) - ); - assert_eq!( - exec_str_no_context("'my name is Yoshikage Kira' ILIKE 'KIRA'"), - Ok(Value::Bool(true)) - ); - assert_eq!( - exec_str_no_context("'my name is Yoshikage Kira' ILIKE 'KIRAA'"), - Ok(Value::Bool(false)) - ); - } - - #[test] - fn exec_unary() { - assert_eq!(exec_str_no_context("+1"), Ok(Value::Int64(1))); - assert_eq!(exec_str_no_context("+ -1"), Ok(Value::Int64(-1))); - assert_eq!(exec_str_no_context("-1"), Ok(Value::Int64(-1))); - assert_eq!(exec_str_no_context("- -1"), Ok(Value::Int64(1))); - assert_eq!(exec_str_no_context("not true"), Ok(Value::Bool(false))); - assert_eq!(exec_str_no_context("not false"), Ok(Value::Bool(true))); - - assert_eq!(exec_str_no_context("true is true"), Ok(Value::Bool(true))); - assert_eq!(exec_str_no_context("false is false"), Ok(Value::Bool(true))); - assert_eq!(exec_str_no_context("false is true"), Ok(Value::Bool(false))); - assert_eq!(exec_str_no_context("true is false"), Ok(Value::Bool(false))); - assert_eq!( - exec_str_no_context("1 is true"), - Err(ValueUnaryOpError { - operator: UnOp::IsTrue, - value: Value::Int64(1) - } - .into()) - ); - - assert_eq!(exec_str_no_context("NULL is NULL"), Ok(Value::Bool(true))); - assert_eq!( - exec_str_no_context("NULL is not NULL"), - Ok(Value::Bool(false)) - ); - assert_eq!(exec_str_no_context("1 is NULL"), Ok(Value::Bool(false))); - assert_eq!(exec_str_no_context("1 is not NULL"), Ok(Value::Bool(true))); - assert_eq!(exec_str_no_context("0 is not NULL"), Ok(Value::Bool(true))); - assert_eq!(exec_str_no_context("'' is not NULL"), Ok(Value::Bool(true))); - } - - #[test] - fn exec_wildcard() { - assert_eq!( - exec_expr_no_context(Expr::Wildcard), - Err(ExprExecError::CannotExecute(Expr::Wildcard)) - ); - } - - #[test] - fn exec_column_ref() { - let mut table = Table::new( - "table1".into(), - vec![ - Column::new( - "col1".into(), - DataType::Int(None), - vec![ColumnOptionDef { - name: None, - option: ColumnOption::Unique { is_primary: true }, - }], - false, - ), - Column::new( - "col2".into(), - DataType::Int(None), - vec![ColumnOptionDef { - name: None, - option: ColumnOption::Unique { is_primary: false }, - }], - false, - ), - Column::new("col3".into(), DataType::String, vec![], false), - ], - ); - table.new_row(vec![ - Value::Int64(4), - Value::Int64(10), - Value::String("brr".to_owned()), - ]); - - assert_eq!( - table.all_data(), - vec![Row::new(vec![ - Value::Int64(4), - Value::Int64(10), - Value::String("brr".to_owned()) - ])] - ); - - assert_eq!( - exec_str_with_context("col1", &table, &table.all_data()[0]), - Ok(Value::Int64(4)) - ); - - assert_eq!( - exec_str_with_context("col3", &table, &table.all_data()[0]), - Ok(Value::String("brr".to_owned())) - ); - - assert_eq!( - exec_str_with_context("col1 = 4", &table, &table.all_data()[0]), - Ok(Value::Bool(true)) - ); - - assert_eq!( - exec_str_with_context("col1 + 1", &table, &table.all_data()[0]), - Ok(Value::Int64(5)) - ); - - assert_eq!( - exec_str_with_context("col1 + col2", &table, &table.all_data()[0]), - Ok(Value::Int64(14)) - ); - - assert_eq!( - exec_str_with_context( - "col1 + col2 = 10 or col1 * col2 = 40", - &table, - &table.all_data()[0] - ), - Ok(Value::Bool(true)) - ); - } -} diff --git a/src/expr/mod.rs b/src/expr/mod.rs index 21e97fd..c6151a6 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -10,6 +10,7 @@ use crate::{ BoundedString, }; +pub mod agg; pub mod eval; /// An expression @@ -130,98 +131,6 @@ impl Display for UnOp { } } -impl TryFrom for Expr { - type Error = ExprError; - fn try_from(expr_ast: ast::Expr) -> Result { - match expr_ast { - ast::Expr::Identifier(i) => Ok(Expr::ColumnRef(vec![i].try_into()?)), - ast::Expr::CompoundIdentifier(i) => Ok(Expr::ColumnRef(i.try_into()?)), - ast::Expr::IsFalse(e) => Ok(Expr::Unary { - op: UnOp::IsFalse, - operand: Box::new((*e).try_into()?), - }), - ast::Expr::IsTrue(e) => Ok(Expr::Unary { - op: UnOp::IsTrue, - operand: Box::new((*e).try_into()?), - }), - ast::Expr::IsNull(e) => Ok(Expr::Unary { - op: UnOp::IsNull, - operand: Box::new((*e).try_into()?), - }), - ast::Expr::IsNotNull(e) => Ok(Expr::Unary { - op: UnOp::IsNotNull, - operand: Box::new((*e).try_into()?), - }), - ast::Expr::Between { - expr, - negated, - low, - high, - } => { - let expr: Box = Box::new((*expr).try_into()?); - let left = Box::new((*low).try_into()?); - let right = Box::new((*high).try_into()?); - let between = Expr::Binary { - left: Box::new(Expr::Binary { - left, - op: BinOp::LessThanOrEqual, - right: expr.clone(), - }), - op: BinOp::And, - right: Box::new(Expr::Binary { - left: expr, - op: BinOp::LessThanOrEqual, - right, - }), - }; - if negated { - Ok(Expr::Unary { - op: UnOp::Not, - operand: Box::new(between), - }) - } else { - Ok(between) - } - } - ast::Expr::BinaryOp { left, op, right } => Ok(Expr::Binary { - left: Box::new((*left).try_into()?), - op: op.try_into()?, - right: Box::new((*right).try_into()?), - }), - ast::Expr::UnaryOp { op, expr } => Ok(Expr::Unary { - op: op.try_into()?, - operand: Box::new((*expr).try_into()?), - }), - ast::Expr::Value(v) => Ok(Expr::Value(v.try_into()?)), - ast::Expr::Function(ref f) => Ok(Expr::Function { - name: f.name.to_string().as_str().into(), - args: f - .args - .iter() - .map(|arg| match arg { - ast::FunctionArg::Unnamed(arg_expr) => match arg_expr { - ast::FunctionArgExpr::Expr(e) => Ok(e.clone().try_into()?), - ast::FunctionArgExpr::Wildcard => Ok(Expr::Wildcard), - ast::FunctionArgExpr::QualifiedWildcard(_) => Err(ExprError::Expr { - reason: "Qualified wildcards are not supported yet", - expr: expr_ast.clone(), - }), - }, - ast::FunctionArg::Named { .. } => Err(ExprError::Expr { - reason: "Named function arguments are not supported", - expr: expr_ast.clone(), - }), - }) - .collect::, _>>()?, - }), - _ => Err(ExprError::Expr { - reason: "Unsupported expression", - expr: expr_ast, - }), - } - } -} - impl TryFrom for BinOp { type Error = ExprError; fn try_from(op: ast::BinaryOperator) -> Result { @@ -284,6 +193,7 @@ pub enum ExprError { }, Value(ValueError), Identifier(IdentifierError), + UnknownAggregateFunction(String), } impl Display for ExprError { @@ -300,6 +210,9 @@ impl Display for ExprError { } ExprError::Value(v) => write!(f, "{}", v), ExprError::Identifier(v) => write!(f, "{}", v), + ExprError::UnknownAggregateFunction(agg) => { + write!(f, "Unsupported Aggregate Function: {}", agg) + } } } } @@ -317,150 +230,3 @@ impl From for ExprError { } impl Error for ExprError {} - -#[cfg(test)] -mod tests { - use sqlparser::{ast, dialect::GenericDialect, parser::Parser, tokenizer::Tokenizer}; - - use crate::{ - expr::{BinOp, Expr, UnOp}, - identifier::ColumnRef, - value::Value, - }; - - #[test] - fn conversion_from_ast() { - fn parse_expr(s: &str) -> ast::Expr { - let dialect = GenericDialect {}; - let mut tokenizer = Tokenizer::new(&dialect, s); - let tokens = tokenizer.tokenize().unwrap(); - let mut parser = Parser::new(tokens, &dialect); - parser.parse_expr().unwrap() - } - - assert_eq!( - parse_expr("abc").try_into(), - Ok(Expr::ColumnRef(ColumnRef { - schema_name: None, - table_name: None, - col_name: "abc".into() - })) - ); - - assert_ne!( - parse_expr("abc").try_into(), - Ok(Expr::ColumnRef(ColumnRef { - schema_name: None, - table_name: None, - col_name: "cab".into() - })) - ); - - assert_eq!( - parse_expr("table1.col1").try_into(), - Ok(Expr::ColumnRef(ColumnRef { - schema_name: None, - table_name: Some("table1".into()), - col_name: "col1".into() - })) - ); - - assert_eq!( - parse_expr("schema1.table1.col1").try_into(), - Ok(Expr::ColumnRef(ColumnRef { - schema_name: Some("schema1".into()), - table_name: Some("table1".into()), - col_name: "col1".into() - })) - ); - - assert_eq!( - parse_expr("5 IS NULL").try_into(), - Ok(Expr::Unary { - op: UnOp::IsNull, - operand: Box::new(Expr::Value(Value::Int64(5))) - }) - ); - - assert_eq!( - parse_expr("1 IS TRUE").try_into(), - Ok(Expr::Unary { - op: UnOp::IsTrue, - operand: Box::new(Expr::Value(Value::Int64(1))) - }) - ); - - assert_eq!( - parse_expr("4 BETWEEN 3 AND 5").try_into(), - Ok(Expr::Binary { - left: Box::new(Expr::Binary { - left: Box::new(Expr::Value(Value::Int64(3))), - op: BinOp::LessThanOrEqual, - right: Box::new(Expr::Value(Value::Int64(4))) - }), - op: BinOp::And, - right: Box::new(Expr::Binary { - left: Box::new(Expr::Value(Value::Int64(4))), - op: BinOp::LessThanOrEqual, - right: Box::new(Expr::Value(Value::Int64(5))) - }) - }) - ); - - assert_eq!( - parse_expr("4 NOT BETWEEN 3 AND 5").try_into(), - Ok(Expr::Unary { - op: UnOp::Not, - operand: Box::new(Expr::Binary { - left: Box::new(Expr::Binary { - left: Box::new(Expr::Value(Value::Int64(3))), - op: BinOp::LessThanOrEqual, - right: Box::new(Expr::Value(Value::Int64(4))) - }), - op: BinOp::And, - right: Box::new(Expr::Binary { - left: Box::new(Expr::Value(Value::Int64(4))), - op: BinOp::LessThanOrEqual, - right: Box::new(Expr::Value(Value::Int64(5))) - }) - }) - }) - ); - - assert_eq!( - parse_expr("MAX(col1)").try_into(), - Ok(Expr::Function { - name: "MAX".into(), - args: vec![Expr::ColumnRef(ColumnRef { - schema_name: None, - table_name: None, - col_name: "col1".into() - })] - }) - ); - - assert_eq!( - parse_expr("some_func(col1, 1, 'abc')").try_into(), - Ok(Expr::Function { - name: "some_func".into(), - args: vec![ - Expr::ColumnRef(ColumnRef { - schema_name: None, - table_name: None, - col_name: "col1".into() - }), - Expr::Value(Value::Int64(1)), - Expr::Value(Value::String("abc".to_owned())) - ] - }) - ); - - assert_eq!( - parse_expr("COUNT(*)").try_into(), - Ok(Expr::Function { - name: "COUNT".into(), - args: vec![Expr::Wildcard] - }) - ); - } -} diff --git a/src/ic.rs b/src/ir.rs similarity index 97% rename from src/ic.rs rename to src/ir.rs index 3d0fc1a..4b4d0b9 100644 --- a/src/ic.rs +++ b/src/ir.rs @@ -5,9 +5,8 @@ use fmt_derive::{Debug, Display}; use sqlparser::ast::{ColumnOptionDef, DataType}; use crate::{ - expr::Expr, + expr::{agg::AggregateFunction, Expr}, identifier::{SchemaRef, TableRef}, - value::Value, vm::RegisterIndex, BoundedString, }; @@ -21,9 +20,6 @@ pub struct IntermediateCode { /// The instruction set of OtterSQL. #[derive(Display, Debug, Clone, PartialEq)] pub enum Instruction { - /// Load a [`Value`] into a register. - Value { index: RegisterIndex, value: Value }, - /// Load a [`Expr`] into a register. Expr { index: RegisterIndex, expr: Expr }, @@ -68,12 +64,32 @@ pub enum Instruction { alias: Option, }, - /// Group the [`Register::TableRef`](`crate::vm::Register::TableRef`) at `index` by the given expression. + Aggregate { + input: RegisterIndex, + output: RegisterIndex, + func: AggregateFunction, + /// Column in input to aggregate. + col_name: BoundedString, + #[display( + "{}", + match alias { + None => "None".to_owned(), + Some(alias) => format!("{}", alias) + } + )] + alias: Option, + }, + + /// Group the [`Register::TableRef`](`crate::vm::Register::TableRef`) at `input` by the given expression. /// - /// This will result in a [`Register::GroupedTable`](`crate::vm::Register::GroupedTable`) being stored at the `index` register. + /// This will result in a [`Register::GroupedTable`](`crate::vm::Register::GroupedTable`) being stored at the `output` register. /// /// Must be added before any projections so as to catch errors in column selections. - GroupBy { index: RegisterIndex, expr: Expr }, + GroupBy { + input: RegisterIndex, + output: RegisterIndex, + expr: Expr, + }, /// Order the [`Register::TableRef`](`crate::vm::Register::TableRef`) at `index` by the given expression. /// diff --git a/src/lib.rs b/src/lib.rs index bbafd36..1ab0c9f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,7 +8,7 @@ pub mod codegen; pub mod column; pub mod database; pub mod expr; -pub mod ic; +pub mod ir; pub mod identifier; pub mod parser; pub mod schema; @@ -18,7 +18,7 @@ pub mod vm; pub use column::Column; pub use database::Database; -pub use ic::{Instruction, IntermediateCode}; +pub use ir::{Instruction, IntermediateCode}; pub use identifier::BoundedString; pub use table::Table; pub use value::Value; diff --git a/src/vm.rs b/src/vm.rs index 8396995..946d720 100644 --- a/src/vm.rs +++ b/src/vm.rs @@ -13,8 +13,8 @@ use crate::codegen::{codegen_ast, CodegenError}; use crate::column::Column; use crate::expr::eval::ExprExecError; use crate::expr::Expr; -use crate::ic::{Instruction, IntermediateCode}; use crate::identifier::{ColumnRef, TableRef}; +use crate::ir::{Instruction, IntermediateCode}; use crate::parser::parse; use crate::schema::Schema; use crate::table::{Row, RowShared, Table}; @@ -133,10 +133,6 @@ impl VirtualMachine { fn execute_instr(&mut self, instr: &Instruction) -> Result, RuntimeError> { let _ = &self.database; match instr { - Instruction::Value { index, value } => { - self.registers - .insert(*index, Register::Value(value.clone())); - } Instruction::Expr { index, expr } => { self.registers.insert(*index, Register::Expr(expr.clone())); } @@ -328,7 +324,18 @@ impl VirtualMachine { return Err(RuntimeError::RegisterNotATable("project", reg.clone())) } }, - Instruction::GroupBy { index: _, expr: _ } => todo!("group by is not implemented yet"), + Instruction::Aggregate { + input: _, + output: _, + func: _, + col_name: _, + alias: _, + } => todo!("aggregate is not implemented yet"), + Instruction::GroupBy { + input: _, + output: _, + expr: _, + } => todo!("group by is not implemented yet"), Instruction::Order { index, expr,