From 285fe5fd119f53c064c3f830534af0e82930fcd6 Mon Sep 17 00:00:00 2001
From: AlSchlo <79570602+AlSchlo@users.noreply.github.com>
Date: Sat, 1 Mar 2025 02:24:30 +0100
Subject: [PATCH] feat: DSL Array Pattern Match & Simplification (#36)

# Add List Pattern Matching to the HIR and Evaluation Engine

## Summary

This PR adds list pattern matching capabilities to the language. It introduces two new pattern variants: `EmptyArray` for matching empty arrays, and `ArrayDecomp` for decomposing arrays into head and tail components.

## Changes

- Added new pattern variants to the `Pattern` enum in the HIR (engine) and the AST (parser).
- Extended the pattern matching logic in `match_pattern` to handle list patterns.
- Added comprehensive tests demonstrating list pattern matching functionality.
- Simplified syntax elements as agreed during the design phase.

## Implementation Details

```rust
// New pattern variants
pub enum Pattern {
    // ... existing variants
    EmptyArray,
    ArrayDecomp(Box<Pattern>, Box<Pattern>),
}
```

The pattern matching logic was extended to handle:
- Empty arrays with the `EmptyArray` pattern
- Non-empty arrays with the `ArrayDecomp` pattern, which splits an array into its head and tail (see the sketch at the end of this description)

## Testing

Added a test that shows a practical application of list pattern matching: a recursive sum function defined as:

```
fn sum(arr: [I64]): I64 = match arr
    | [] -> 0
    \ [head .. tail] -> head + sum(tail)
```

The test verifies that this function correctly computes:
- `sum([]) = 0`
- `sum([42]) = 42`
- `sum([1, 2, 3]) = 6`

## Related Changes

As part of this PR, I also simplified some syntax elements:
- Changed the composition operator from `->` to `.`
- Changed the match arrow from `=>` to `->`
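To make the intended semantics concrete before the diff, here is a small, self-contained sketch of how `EmptyArray` and `ArrayDecomp` are meant to behave. The `ToyPattern`/`ToyValue` types and the flat binding environment are invented for illustration only; the real implementation works over the HIR's `Value` and `Context` types in `match_pattern` (see `optd-dsl/src/engine/eval/match.rs` below).

```rust
// Toy model of the two new pattern variants (not the engine's actual types).
enum ToyPattern {
    Wildcard,
    Bind(String, Box<ToyPattern>),
    EmptyArray,
    ArrayDecomp(Box<ToyPattern>, Box<ToyPattern>),
}

#[derive(Clone, Debug)]
enum ToyValue {
    Int(i64),
    Array(Vec<ToyValue>),
}

type Env = Vec<(String, ToyValue)>;

fn matches(value: &ToyValue, pattern: &ToyPattern, env: &mut Env) -> bool {
    match (pattern, value) {
        (ToyPattern::Wildcard, _) => true,
        // A binding matches its inner pattern and records the value under `name`.
        (ToyPattern::Bind(name, inner), v) => {
            if matches(v, inner, env) {
                env.push((name.clone(), v.clone()));
                true
            } else {
                false
            }
        }
        // `[]` only matches an empty array.
        (ToyPattern::EmptyArray, ToyValue::Array(items)) => items.is_empty(),
        // `[head .. tail]` splits a non-empty array into its first element and
        // the (possibly empty) remainder, then matches both sub-patterns.
        (ToyPattern::ArrayDecomp(head_pat, tail_pat), ToyValue::Array(items)) => {
            match items.split_first() {
                None => false,
                Some((head, tail)) => {
                    matches(head, head_pat, env)
                        && matches(&ToyValue::Array(tail.to_vec()), tail_pat, env)
                }
            }
        }
        _ => false,
    }
}

fn main() {
    // `[]` matches the empty array.
    assert!(matches(&ToyValue::Array(vec![]), &ToyPattern::EmptyArray, &mut Env::new()));

    // `[head .. tail]` against [1, 2, 3] binds head = 1, tail = [2, 3].
    let pattern = ToyPattern::ArrayDecomp(
        Box::new(ToyPattern::Bind("head".into(), Box::new(ToyPattern::Wildcard))),
        Box::new(ToyPattern::Bind("tail".into(), Box::new(ToyPattern::Wildcard))),
    );
    let value = ToyValue::Array(vec![ToyValue::Int(1), ToyValue::Int(2), ToyValue::Int(3)]);
    let mut env = Env::new();
    assert!(matches(&value, &pattern, &mut env));
    println!("{:?}", env); // [("head", Int(1)), ("tail", Array([Int(2), Int(3)]))]
}
```

The tail sub-pattern receives the whole remainder as an array, which is why the recursive `sum` definition above bottoms out at the `[]` arm.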
---
 optd-dsl/src/analyzer/hir.rs       |   2 +
 optd-dsl/src/cli/basic.op          |  32 ++---
 optd-dsl/src/engine/eval/binary.rs |   2 +-
 optd-dsl/src/engine/eval/expr.rs   |  67 +++++++++
 optd-dsl/src/engine/eval/match.rs  |  30 ++++
 optd-dsl/src/lexer/lex.rs          |   2 -
 optd-dsl/src/lexer/tokens.rs       |   4 -
 optd-dsl/src/parser/adt.rs         |  18 +--
 optd-dsl/src/parser/ast.rs         |   6 +-
 optd-dsl/src/parser/expr.rs        | 212 ++++++++++++++++++++++++-----
 optd-dsl/src/parser/function.rs    |   4 +-
 optd-dsl/src/parser/module.rs      |  34 ++---
 optd-dsl/src/parser/pattern.rs     | 102 ++++++++++++++
 optd-dsl/src/parser/type.rs        |  12 +-
 14 files changed, 435 insertions(+), 92 deletions(-)

diff --git a/optd-dsl/src/analyzer/hir.rs b/optd-dsl/src/analyzer/hir.rs
index 95e0de9..c3ebb03 100644
--- a/optd-dsl/src/analyzer/hir.rs
+++ b/optd-dsl/src/analyzer/hir.rs
@@ -101,6 +101,8 @@ pub enum Pattern {
     Struct(Identifier, Vec<Pattern>),
     Operator(Operator),
     Wildcard,
+    EmptyArray,
+    ArrayDecomp(Box<Pattern>, Box<Pattern>),
 }
 
 /// Match arm combining pattern and expression
diff --git a/optd-dsl/src/cli/basic.op b/optd-dsl/src/cli/basic.op
index 829b922..187b432 100644
--- a/optd-dsl/src/cli/basic.op
+++ b/optd-dsl/src/cli/basic.op
@@ -1,18 +1,18 @@
 data LogicalProps(schema_len: I64)
 
-data Scalar with
+data Scalar =
   | ColumnRef(idx: Int64)
-  | Literal with
+  | Literal =
     | IntLiteral(value: Int64)
     | StringLiteral(value: String)
     | BoolLiteral(value: Bool)
     \ NullLiteral
-  | Arithmetic with
+  | Arithmetic =
    | Mult(left: Scalar, right: Scalar)
    | Add(left: Scalar, right: Scalar)
    | Sub(left: Scalar, right: Scalar)
    \ Div(left: Scalar, right: Scalar)
-  | Predicate with
+  | Predicate =
    | And(children: [Predicate])
    | Or(children: [Predicate])
    | Not(child: Predicate)
@@ -24,18 +24,18 @@ data Scalar with
    | GreaterThanEqual(left: Scalar, right: Scalar)
    | IsNull(expr: Scalar)
    \ IsNotNull(expr: Scalar)
-  | Function with
+  | Function =
    | Cast(expr: Scalar, target_type: String)
    | Substring(str: Scalar, start: Scalar, length: Scalar)
    \ Concat(args: [Scalar])
-  \ AggregateExpr with
+  \ AggregateExpr =
    | Sum(expr: Scalar)
    | Count(expr: Scalar)
    | Min(expr: Scalar)
    | Max(expr: Scalar)
    \ Avg(expr: Scalar)
 
-data Logical with
+data Logical =
   | Scan(table_name: String)
   | Filter(child: Logical, cond: Predicate)
   | Project(child: Logical, exprs: [Scalar])
@@ -51,11 +51,11 @@ data Logical with
     aggregates: [AggregateExpr]
   )
 
-data Physical with
+data Physical =
   | Scan(table_name: String)
   | Filter(child: Physical, cond: Predicate)
   | Project(child: Physical, exprs: [Scalar])
-  | Join with
+  | Join =
    | HashJoin(
       build_side: Physical,
       probe_side: Physical,
@@ -84,7 +84,7 @@ data Physical with
     order_by: [(Scalar, SortOrder)]
   )
 
-data JoinType with
+data JoinType =
   | Inner
   | Left
   | Right
@@ -96,19 +96,19 @@ fn (expr: Scalar) apply_children(f: Scalar => Scalar) = ()
 
 fn (pred: Predicate) remap(map: {I64 : I64}) = match predicate
-  | ColumnRef(idx) => ColumnRef(map(idx))
-  \ _ => predicate -> apply_children(child => rewrite_column_refs(child, map))
+  | ColumnRef(idx) -> ColumnRef(map(idx))
+  \ _ -> predicate.apply_children(child -> rewrite_column_refs(child, map))
 
 [rule]
 fn (expr: Logical) join_commute = match expr
-  \ Join(left, right, Inner, cond) =>
+  \ Join(left, right, Inner, cond) ->
     let right_indices = 0..right.schema_len,
        left_indices = 0..left.schema_len,
-       remapping = left_indices.map(i => (i, i + right_len)) ++
-          right_indices.map(i => (left_len + i, i)).to_map,
+       remapping = left_indices.map(i -> (i, i + right_len)) ++
+          right_indices.map(i -> (left_len + i, i)).to_map,
     in Project(
       Join(right, left, Inner, cond.remap(remapping)),
-      right_indices.map(i => ColumnRef(i)).to_array
+      right_indices.map(i -> ColumnRef(i)).to_array
     )
\ No newline at end of file
diff --git a/optd-dsl/src/engine/eval/binary.rs b/optd-dsl/src/engine/eval/binary.rs
index 24b30a5..f3087a2 100644
--- a/optd-dsl/src/engine/eval/binary.rs
+++ b/optd-dsl/src/engine/eval/binary.rs
@@ -94,7 +94,7 @@ pub(super) fn eval_binary_op(left: Value, op: &BinOp, right: Value) -> Value {
         }
 
         // Any other combination of value types or operations is not supported
-        _ => panic!("Invalid binary operation"),
+        _ => panic!("Invalid binary operation: {:?} {:?} {:?}", left, op, right),
     }
 }
diff --git a/optd-dsl/src/engine/eval/expr.rs b/optd-dsl/src/engine/eval/expr.rs
index efedfa1..f7a5a44 100644
--- a/optd-dsl/src/engine/eval/expr.rs
+++ b/optd-dsl/src/engine/eval/expr.rs
@@ -665,4 +665,71 @@ mod tests {
         assert_eq!(values.len(), 1);
         assert!(matches!(&values[0].0, Literal(Int64(30)))); // 10 + 20 = 30 (since 10 < 20)
     }
+
+    #[test]
+    fn test_recursive_list_sum() {
+        let context = Context::new(HashMap::new());
+
+        // Define a recursive sum function using pattern matching
+        // sum([]) = 0
+        // sum([x .. xs]) = x + sum(xs)
+        let sum_function = Value(Function(Closure(
+            vec!["arr".to_string()],
+            Box::new(PatternMatch(
+                Box::new(Ref("arr".to_string())),
+                vec![
+                    // Base case: empty array returns 0
+                    MatchArm {
+                        pattern: Pattern::EmptyArray,
+                        expr: CoreVal(int_val(0)),
+                    },
+                    // Recursive case: add head + sum(tail)
+                    MatchArm {
+                        pattern: Pattern::ArrayDecomp(
+                            Box::new(Bind("head".to_string(), Box::new(Wildcard))),
+                            Box::new(Bind("tail".to_string(), Box::new(Wildcard))),
+                        ),
+                        expr: Binary(
+                            Box::new(Ref("head".to_string())),
+                            BinOp::Add,
+                            Box::new(Call(
+                                Box::new(Ref("sum".to_string())),
+                                vec![Ref("tail".to_string())],
+                            )),
+                        ),
+                    },
+                ],
+            )),
+        )));
+
+        // Bind the recursive function in the context
+        let mut test_context = context.clone();
+        test_context.bind("sum".to_string(), sum_function);
+
+        // Test arrays
+        let empty_array = Value(CoreData::Array(vec![]));
+        let array_123 = Value(CoreData::Array(vec![int_val(1), int_val(2), int_val(3)]));
+        let array_42 = Value(CoreData::Array(vec![int_val(42)]));
+
+        // Test 1: Sum of empty array should be 0
+        let call_empty = Call(Box::new(Ref("sum".to_string())), vec![CoreVal(empty_array)]);
+
+        let result = collect_stream_values(call_empty.evaluate(test_context.clone()));
+        assert_eq!(result.len(), 1);
+        assert!(matches!(&result[0].0, Literal(Int64(n)) if *n == 0));
+
+        // Test 2: Sum of [1, 2, 3] should be 6
+        let call_123 = Call(Box::new(Ref("sum".to_string())), vec![CoreVal(array_123)]);
+
+        let result = collect_stream_values(call_123.evaluate(test_context.clone()));
+        assert_eq!(result.len(), 1);
+        assert!(matches!(&result[0].0, Literal(Int64(n)) if *n == 6));
+
+        // Test 3: Sum of [42] should be 42
+        let call_42 = Call(Box::new(Ref("sum".to_string())), vec![CoreVal(array_42)]);
+
+        let result = collect_stream_values(call_42.evaluate(test_context));
+        assert_eq!(result.len(), 1);
+        assert!(matches!(&result[0].0, Literal(Int64(n)) if *n == 42));
+    }
 }
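For readers skimming the HIR construction above: the recursion that `test_recursive_list_sum` drives through the engine corresponds to this ordinary Rust function. It is illustrative only, not engine API; the test builds the equivalent closure out of HIR nodes as shown in the diff.

```rust
// Plain-Rust analogue of the DSL function under test:
// sum([]) = 0, sum([head .. tail]) = head + sum(tail)
fn sum(arr: &[i64]) -> i64 {
    match arr.split_first() {
        None => 0,                                  // [] -> 0
        Some((head, tail)) => *head + sum(tail),    // [head .. tail] -> head + sum(tail)
    }
}

fn main() {
    assert_eq!(sum(&[]), 0);
    assert_eq!(sum(&[42]), 42);
    assert_eq!(sum(&[1, 2, 3]), 6);
    println!("ok");
}
```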
diff --git a/optd-dsl/src/engine/eval/match.rs b/optd-dsl/src/engine/eval/match.rs
index 1ec6192..17411ff 100644
--- a/optd-dsl/src/engine/eval/match.rs
+++ b/optd-dsl/src/engine/eval/match.rs
@@ -101,6 +101,36 @@ async fn match_pattern(value: Value, pattern: Pattern, context: Context) -> Vec<Context> {
             _ => vec![],
         },
 
+        // Empty list pattern: match if value is an empty array
+        (EmptyArray, CoreData::Array(arr)) if arr.is_empty() => vec![context],
+
+        // List decomposition pattern: match first element and rest of the array
+        (ArrayDecomp(head_pattern, tail_pattern), CoreData::Array(arr)) => {
+            if arr.is_empty() {
+                return vec![];
+            }
+
+            // Split array into head and tail
+            let head = arr[0].clone();
+            let tail = Value(CoreData::Array(arr[1..].to_vec()));
+
+            // Match head against head pattern
+            let head_contexts = match_pattern(head, (**head_pattern).clone(), context).await;
+            if head_contexts.is_empty() {
+                return vec![];
+            }
+
+            // For each successful head match, try to match tail
+            let mut result_contexts = Vec::new();
+            for head_ctx in head_contexts {
+                let tail_contexts =
+                    match_pattern(tail.clone(), (**tail_pattern).clone(), head_ctx).await;
+                result_contexts.extend(tail_contexts);
+            }
+
+            result_contexts
+        }
+
         // Struct pattern: match name and recursively match fields
         (Struct(pat_name, field_patterns), CoreData::Struct(val_name, field_values)) => {
             if pat_name != val_name || field_patterns.len() != field_values.len() {
diff --git a/optd-dsl/src/lexer/lex.rs b/optd-dsl/src/lexer/lex.rs
index 13ee023..5f1b2d1 100644
--- a/optd-dsl/src/lexer/lex.rs
+++ b/optd-dsl/src/lexer/lex.rs
@@
-54,8 +54,6 @@ fn lexer() -> impl Parser, Error = Simple> ("false", Token::Bool(false)), ("Unit", Token::TUnit), ("data", Token::Data), - ("with", Token::With), - ("as", Token::As), ("in", Token::In), ("let", Token::Let), ("match", Token::Match), diff --git a/optd-dsl/src/lexer/tokens.rs b/optd-dsl/src/lexer/tokens.rs index 2100d84..a13d86c 100644 --- a/optd-dsl/src/lexer/tokens.rs +++ b/optd-dsl/src/lexer/tokens.rs @@ -12,8 +12,6 @@ pub enum Token { // Other keywords Fn, Data, - With, - As, In, Let, Match, @@ -85,8 +83,6 @@ impl std::fmt::Display for Token { // Other keywords Token::Fn => write!(f, "fn"), Token::Data => write!(f, "data"), - Token::With => write!(f, "with"), - Token::As => write!(f, "as"), Token::In => write!(f, "in"), Token::Let => write!(f, "let"), Token::Match => write!(f, "match"), diff --git a/optd-dsl/src/parser/adt.rs b/optd-dsl/src/parser/adt.rs index 30f4465..4182548 100644 --- a/optd-dsl/src/parser/adt.rs +++ b/optd-dsl/src/parser/adt.rs @@ -25,7 +25,7 @@ pub fn adt_parser() -> impl Parser, Error = Simple, Spanned), } /// Represents a single arm in a pattern match expression @@ -212,8 +216,6 @@ pub enum UnaryOp { pub enum PostfixOp { /// Function or method call with arguments Call(Vec>), - /// Function composition operator - Compose(Identifier), /// Member/field access Member(Identifier), } diff --git a/optd-dsl/src/parser/expr.rs b/optd-dsl/src/parser/expr.rs index 5f994a4..bd24e1c 100644 --- a/optd-dsl/src/parser/expr.rs +++ b/optd-dsl/src/parser/expr.rs @@ -151,7 +151,7 @@ pub fn expr_parser() -> impl Parser, Error = Simple impl Parser, Error = Simple name }) .map(PostfixOp::Member), - just(Token::SmallArrow) - .ignore_then(select! { Token::TermIdent(name) => name }) - .map(PostfixOp::Compose), )) .map_with_span(|op, span| (op, span)) .repeated(), @@ -308,7 +305,7 @@ pub fn expr_parser() -> impl Parser, Error = Simple { assert_eq!(a_member, e_member); } - (PostfixOp::Compose(a_id), PostfixOp::Compose(e_id)) => { - assert_eq!(a_id, e_id); - } _ => panic!( "Postfix operation mismatch: expected {:?}, got {:?}", e_op, a_op @@ -923,7 +917,7 @@ mod tests { fn test_nested_closures_and_calls() { // Test nested closures with function calls let (result, errors) = - parse_expr("(x: I64) => (y: F64) => func(x + 1, y * 2.5, z => z && true)"); + parse_expr("(x: I64) -> (y: F64) -> func(x + 1, y * 2.5, z -> z && true)"); assert!( result.is_some(), @@ -1211,7 +1205,7 @@ mod tests { } // Let with complex body - let (result, errors) = parse_expr("let f: I64 => I64 = (x) => x * x in f(10)"); + let (result, errors) = parse_expr("let f: I64 -> I64 = (x) -> x * x in f(10)"); assert!( result.is_some(), "Expected successful parse for let with complex body and type annotation" @@ -1356,7 +1350,7 @@ mod tests { } // Test compose operator - let (result, errors) = parse_expr("map(dat) -> filter"); + let (result, errors) = parse_expr("map(dat).filter"); assert!( result.is_some(), @@ -1365,7 +1359,7 @@ mod tests { assert!(errors.is_empty(), "Expected no errors for compose operator"); if let Some(expr) = result { - if let Expr::Postfix(inner, PostfixOp::Compose(name)) = &*expr.value { + if let Expr::Postfix(inner, PostfixOp::Member(name)) = &*expr.value { assert_eq!(name, "filter"); if let Expr::Postfix(func, PostfixOp::Call(args)) = &*inner.value { @@ -1381,7 +1375,7 @@ mod tests { } // Test chained compose operators - let (result, errors) = parse_expr("transform(input) -> map -> filter -> reduce"); + let (result, errors) = parse_expr("transform(input).map.filter.reduce"); assert!( 
result.is_some(), "Expected successful parse for chained compose operators" @@ -1392,13 +1386,13 @@ mod tests { ); if let Some(expr) = result { - if let Expr::Postfix(inner1, PostfixOp::Compose(name1)) = &*expr.value { + if let Expr::Postfix(inner1, PostfixOp::Member(name1)) = &*expr.value { assert_eq!(name1, "reduce"); - if let Expr::Postfix(inner2, PostfixOp::Compose(name2)) = &*inner1.value { + if let Expr::Postfix(inner2, PostfixOp::Member(name2)) = &*inner1.value { assert_eq!(name2, "filter"); - if let Expr::Postfix(inner3, PostfixOp::Compose(name3)) = &*inner2.value { + if let Expr::Postfix(inner3, PostfixOp::Member(name3)) = &*inner2.value { assert_eq!(name3, "map"); if let Expr::Postfix(func, PostfixOp::Call(args)) = &*inner3.value { @@ -1424,7 +1418,7 @@ mod tests { fn test_match_expressions() { // Simple match expression let (result, errors) = - parse_expr("match x | 1 => \"one\" | 2 => \"two\" \\ _ => \"other\""); + parse_expr("match x | 1 -> \"one\" | 2 -> \"two\" \\ _ -> \"other\""); assert!( result.is_some(), "Expected successful parse for simple match expression" @@ -1478,10 +1472,10 @@ mod tests { // Match with complex patterns and expressions let (result, errors) = parse_expr( - "match point | Point(x, y) => \"first quadrant\" \ - | Circle(r) => \"circle\" \ - | Rectangle(b: Stuff(_), h) => \"rectangle\" \ - \\ _ => \"unknown shape\"", + "match point | Point(x, y) -> \"first quadrant\" \ + | Circle(r) -> \"circle\" \ + | Rectangle(b: Stuff(_), h) -> \"rectangle\" \ + \\ _ -> \"unknown shape\"", ); assert!( result.is_some(), @@ -1555,12 +1549,12 @@ mod tests { #[test] fn test_crazy_composite_expression() { // This test creates an insanely complex nested expression with multiple features - let crazy_expr = "let create_calculator = (operation) => match operation \ - | \"add\" => (x, y) => x + y \ - | \"subtract\" => (x, y) => x - y \ - | \"multiply\" => (x, y) => x * y \ - | \"divide\" => (x, y) => if y == 0 then fail(\"Division by zero\") else x / y \ - \\ _ => (x, y) => -1, \ + let crazy_expr = "let create_calculator = (operation) -> match operation \ + | \"add\" -> (x, y) -> x + y \ + | \"subtract\" -> (x, y) -> x - y \ + | \"multiply\" -> (x, y) -> x * y \ + | \"divide\" -> (x, y) -> if y == 0 then fail(\"Division by zero\") else x / y \ + \\ _ -> (x, y) -> -1, \ calc = create_calculator(\"multiply\"), \ result = calc({\"key\": 6}.key, 7), \ in if result > 40 \ @@ -1610,12 +1604,12 @@ mod tests { // Test the chained version of the let expressions let chained_crazy_expr = "let \ - create_calculator = (operation) => match operation \ - | \"add\" => (x, y) => x + y \ - | \"subtract\" => (x, y) => x - y \ - | \"multiply\" => (x, y) => x * y \ - | \"divide\" => (x, y) => if y == 0 then fail(\"Division by zero\") else x / y \ - \\ _ => (x, y) => -1, \ + create_calculator = (operation) -> match operation \ + | \"add\" -> (x, y) -> x + y \ + | \"subtract\" -> (x, y) -> x - y \ + | \"multiply\" -> (x, y) -> x * y \ + | \"divide\" -> (x, y) -> if y == 0 then fail(\"Division by zero\") else x / y \ + \\ _ -> (x, y) -> -1, \ calc = create_calculator(\"multiply\"), \ result: I64 = calc({\"key\": 6}.key, 7) \ in if result > 40 \ @@ -1664,4 +1658,156 @@ mod tests { } } } + + #[test] + fn test_list_decomposition_match() { + // Complex match expression with list decomposition patterns + let (result, errors) = parse_expr( + "match numbers \ + | [] -> \"empty list\" \ + | [x .. []] -> \"list with one element: \" ++ x.to_string() \ + | [x .. [y .. 
[]]] -> \"list with two elements: \" ++ x.to_string() ++ \", \" ++ y.to_string() \ + | [head .. [second .. tail]] -> { \ + let sum = head + second, \ + rest_count = tail.length() \ + in \"list with \" ++ (rest_count + 2).to_string() ++ \" elements, first two sum: \" ++ sum.to_string() \ + } \ + \\ _ -> \"not a list\"", + ); + + assert!( + result.is_some(), + "Expected successful parse for match with list decomposition patterns" + ); + assert!( + errors.is_empty(), + "Expected no errors for match with list decomposition patterns" + ); + + if let Some(expr) = result { + if let Expr::PatternMatch(scrutinee, arms) = &*expr.value { + assert_expr_eq(&scrutinee.value, &Expr::Ref("numbers".to_string())); + assert_eq!(arms.len(), 5, "Expected 5 match arms"); + + // First arm: [] -> "empty list" + if let Pattern::EmptyArray = *arms[0].value.pattern.value { + // Empty list pattern is correct + } else { + panic!( + "Expected EmptyList pattern in first arm, got {:?}", + arms[0].value.pattern.value + ); + } + assert_expr_eq( + &arms[0].value.expr.value, + &Expr::Literal(Literal::String("empty list".to_string())), + ); + + // Second arm: [x .. []] -> "list with one element: " ++ x.to_string() + if let Pattern::ArrayDecomp(head, tail) = &*arms[1].value.pattern.value { + // Check head is a binding pattern for 'x' + if let Pattern::Bind(name, _) = &*head.value { + assert_eq!(*name.value, "x", "Expected head binding named 'x'"); + } else { + panic!("Expected binding pattern for head in second arm"); + } + + // Check tail is an empty list + if let Pattern::EmptyArray = *tail.value { + // Empty list is correct + } else { + panic!("Expected EmptyList pattern for tail in second arm"); + } + } else { + panic!("Expected ListDecomposition pattern in second arm"); + } + + // Third arm: [x .. [y .. []]] -> "list with two elements: " ++ x.to_string() ++ ", " ++ y.to_string() + if let Pattern::ArrayDecomp(outer_head, outer_tail) = &*arms[2].value.pattern.value + { + // Check outer_head is a binding pattern for 'x' + if let Pattern::Bind(name, _) = &*outer_head.value { + assert_eq!(*name.value, "x", "Expected outer head binding named 'x'"); + } else { + panic!("Expected binding pattern for outer head in third arm"); + } + + // Check outer_tail is another list decomposition + if let Pattern::ArrayDecomp(inner_head, inner_tail) = &*outer_tail.value { + // Check inner_head is a binding pattern for 'y' + if let Pattern::Bind(name, _) = &*inner_head.value { + assert_eq!(*name.value, "y", "Expected inner head binding named 'y'"); + } else { + panic!("Expected binding pattern for inner head in third arm"); + } + + // Check inner_tail is an empty list + if let Pattern::EmptyArray = *inner_tail.value { + // Empty list is correct + } else { + panic!("Expected EmptyList pattern for inner tail in third arm"); + } + } else { + panic!("Expected nested ListDecomposition pattern in third arm"); + } + } else { + panic!("Expected ListDecomposition pattern in third arm"); + } + + // Fourth arm: [head .. [second .. 
tail]] -> { complex expression } + if let Pattern::ArrayDecomp(outer_head, outer_tail) = &*arms[3].value.pattern.value + { + // Check outer_head is a binding pattern for 'head' + if let Pattern::Bind(name, _) = &*outer_head.value { + assert_eq!( + *name.value, "head", + "Expected outer head binding named 'head'" + ); + } else { + panic!("Expected binding pattern for outer head in fourth arm"); + } + + // Check outer_tail is another list decomposition + if let Pattern::ArrayDecomp(inner_head, inner_tail) = &*outer_tail.value { + // Check inner_head is a binding pattern for 'second' + if let Pattern::Bind(name, _) = &*inner_head.value { + assert_eq!( + *name.value, "second", + "Expected inner head binding named 'second'" + ); + } else { + panic!("Expected binding pattern for inner head in fourth arm"); + } + + // Check inner_tail is a binding pattern for 'tail' + if let Pattern::Bind(name, _) = &*inner_tail.value { + assert_eq!( + *name.value, "tail", + "Expected inner tail binding named 'tail'" + ); + } else { + panic!("Expected binding pattern for inner tail in fourth arm"); + } + } else { + panic!("Expected nested ListDecomposition pattern in fourth arm"); + } + } else { + panic!("Expected ListDecomposition pattern in fourth arm"); + } + + // Fifth arm: _ -> "not a list" + if let Pattern::Wildcard = *arms[4].value.pattern.value { + // Wildcard pattern is correct + } else { + panic!("Expected Wildcard pattern in fifth arm"); + } + assert_expr_eq( + &arms[4].value.expr.value, + &Expr::Literal(Literal::String("not a list".to_string())), + ); + } else { + panic!("Expected pattern match expression"); + } + } + } } diff --git a/optd-dsl/src/parser/function.rs b/optd-dsl/src/parser/function.rs index 2f23dfb..4f927c1 100644 --- a/optd-dsl/src/parser/function.rs +++ b/optd-dsl/src/parser/function.rs @@ -405,7 +405,7 @@ mod tests { #[test] fn test_function_with_complex_return_type() { - let input = "fn process(dat: [I64]): (I64) => {String: [Bool]} = (x) => {\"result\": [true, false]}"; + let input = "fn process(dat: [I64]): (I64) -> {String: [Bool]} = (x) -> {\"result\": [true, false]}"; let (result, errors) = parse_function(input); assert!(result.is_some(), "Expected successful parse"); @@ -448,7 +448,7 @@ mod tests { #[test] fn test_extern_function_with_complex_return_type() { - let input = "fn nativeProcess(dat: [I64]): (I64) => {String: [Bool]}"; + let input = "fn nativeProcess(dat: [I64]): I64 -> {String: [Bool]}"; let (result, errors) = parse_function(input); assert!(result.is_some(), "Expected successful parse"); diff --git a/optd-dsl/src/parser/module.rs b/optd-dsl/src/parser/module.rs index e97a926..06097c9 100644 --- a/optd-dsl/src/parser/module.rs +++ b/optd-dsl/src/parser/module.rs @@ -57,19 +57,19 @@ mod tests { let source = r#" data LogicalProps(schema_len: I64) - data Scalar with + data Scalar = | ColumnRef(idx: Int64) - | Literal with + | Literal = | IntLiteral(value: Int64) | StringLiteral(value: String) | BoolLiteral(value: Bool) \ NullLiteral - | Arithmetic with + | Arithmetic = | Mult(left: Scalar, right: Scalar) | Add(left: Scalar, right: Scalar) | Sub(left: Scalar, right: Scalar) \ Div(left: Scalar, right: Scalar) - | Predicate with + | Predicate = | And(children: [Predicate]) | Or(children: [Predicate]) | Not(child: Predicate) @@ -81,18 +81,18 @@ mod tests { | GreaterThanEqual(left: Scalar, right: Scalar) | IsNull(expr: Scalar) \ IsNotNull(expr: Scalar) - | Function with + | Function = | Cast(expr: Scalar, target_type: String) | Substring(str: Scalar, start: Scalar, length: 
Scalar) \ Concat(args: [Scalar]) - \ AggregateExpr with + \ AggregateExpr = | Sum(expr: Scalar) | Count(expr: Scalar) | Min(expr: Scalar) | Max(expr: Scalar) \ Avg(expr: Scalar) - data Logical with + data Logical = | Scan(table_name: String) | Filter(child: Logical, cond: Predicate) | Project(child: Logical, exprs: [Scalar]) @@ -108,11 +108,11 @@ mod tests { aggregates: [AggregateExpr] ) - data Physical with + data Physical = | Scan(table_name: String) | Filter(child: Physical, cond: Predicate) | Project(child: Physical, exprs: [Scalar]) - | Join with + | Join = | HashJoin( build_side: Physical, probe_side: Physical, @@ -141,7 +141,7 @@ mod tests { order_by: [(Scalar, SortOrder)] ) - data JoinType with + data JoinType = | Inner | Left | Right @@ -149,25 +149,25 @@ mod tests { \ Semi [rust] - fn (expr: Scalar) apply_children(f: Scalar => Scalar) = () + fn (expr: Scalar) apply_children(f: Scalar -> Scalar) = () fn (pred: Predicate) remap(map: {I64 : I64}) = match predicate - | ColumnRef(idx) => ColumnRef(map(idx)) - \ _ => predicate -> apply_children(child => rewrite_column_refs(child, map)) + | ColumnRef(idx) -> ColumnRef(map(idx)) + \ _ -> predicate.apply_children(child.rewrite_column_refs(child, map)) [rule] fn (expr: Logical) join_commute = match expr - \ Join(left, right, Inner, cond) => + \ Join(left, right, Inner, cond) -> let right_indices = 0..right.schema_len, left_indices = 0..left.schema_len, - remapping = left_indices.map(i => (i, i + right_len)) ++ - right_indices.map(i => (left_len + i, i)).to_map, + remapping = left_indices.map(i -> (i, i + right_len)) ++ + right_indices.map(i -> (left_len + i, i)).to_map, in Project( Join(right, left, Inner, cond.remap(remapping)), - right_indices.map(i => ColumnRef(i)).to_array + right_indices.map(i -> ColumnRef(i)).to_array ) "#; diff --git a/optd-dsl/src/parser/pattern.rs b/optd-dsl/src/parser/pattern.rs index 8550c52..1aec5fa 100644 --- a/optd-dsl/src/parser/pattern.rs +++ b/optd-dsl/src/parser/pattern.rs @@ -71,6 +71,19 @@ pub fn pattern_parser() -> impl Parser, Error = Simple impl Parser, Error = Simple {} + _ => panic!("Expected EmptyList pattern, got {:?}", pattern.value), + } + } + } + + #[test] + fn test_list_decomposition_pattern() { + let (result, errors) = parse_pattern("[x .. xs]"); + + assert!( + result.is_some(), + "Expected successful parse for list decomposition" + ); + assert!( + errors.is_empty(), + "Expected no errors for list decomposition" + ); + + if let Some(pattern) = result { + match *pattern.value { + Pattern::ArrayDecomp(head, tail) => { + // Check that head is a pattern binding 'x' + match *head.value { + Pattern::Bind(name, _) => { + assert_eq!(*name.value, "x", "Expected head binding named 'x'"); + } + _ => panic!("Expected head to be a binding pattern"), + } + + // Check that tail is a pattern binding 'xs' + match *tail.value { + Pattern::Bind(name, _) => { + assert_eq!(*name.value, "xs", "Expected tail binding named 'xs'"); + } + _ => panic!("Expected tail to be a binding pattern"), + } + } + _ => panic!("Expected ListDecomposition pattern"), + } + } + } + + #[test] + fn test_list_decomposition_with_empty_tail() { + let (result, errors) = parse_pattern("[x .. 
[]]"); + + assert!( + result.is_some(), + "Expected successful parse for decomposition with empty tail" + ); + assert!( + errors.is_empty(), + "Expected no errors for decomposition with empty tail" + ); + + if let Some(pattern) = result { + match *pattern.value { + Pattern::ArrayDecomp(head, tail) => { + // Check that head is a pattern binding 'x' + match *head.value { + Pattern::Bind(name, _) => { + assert_eq!(*name.value, "x", "Expected head binding named 'x'"); + } + _ => panic!("Expected head to be a binding pattern"), + } + + // Check that tail is an empty list pattern + match *tail.value { + Pattern::EmptyArray => {} + _ => panic!("Expected tail to be an empty list pattern"), + } + } + _ => panic!("Expected ListDecomposition pattern"), + } + } + } } diff --git a/optd-dsl/src/parser/type.rs b/optd-dsl/src/parser/type.rs index e1762eb..7558735 100644 --- a/optd-dsl/src/parser/type.rs +++ b/optd-dsl/src/parser/type.rs @@ -84,7 +84,7 @@ pub fn type_parser() -> impl Parser, Error = Simple String").unwrap(); + let result = parse_type("(I64) -> String").unwrap(); assert!(matches!(*result.value, Type::Closure(param, ret) if matches!(*param.value, Type::Int64) && matches!(*ret.value, Type::String) )); - let result = parse_type("I64 => String").unwrap(); + let result = parse_type("I64 -> String").unwrap(); assert!(matches!(*result.value, Type::Closure(param, ret) if matches!(*param.value, Type::Int64) @@ -198,7 +198,7 @@ mod tests { )); // Is right-associative - let result = parse_type("I64 => String => String").unwrap(); + let result = parse_type("I64 -> String -> String").unwrap(); assert!(matches!(*result.value, Type::Closure(param, ret) if matches!(*param.value, Type::Int64) @@ -213,7 +213,7 @@ mod tests { #[test] fn test_complex_type() { // Test mix of Map, Array, Tuple, and Closure - let insane_type = "{String : [((I64, [{String : Physical}]) => [(AdtType, LogicalProps, (Bool => [Scalar]))])]}"; + let insane_type = "{String : [((I64, [{String : Physical}]) -> [(AdtType, LogicalProps, (Bool -> [Scalar]))])]}"; let result = parse_type(insane_type).unwrap(); let map_type = result.value; @@ -262,7 +262,7 @@ mod tests { // Test an even more complex nested type let even_more_insane = - "{String : {I64 : [(Logical => {String : [((Bool, [Scalar]) => Physical)]})]}}"; + "{String : {I64 : [(Logical -> {String : [((Bool, [Scalar]) -> Physical)]})]}}"; assert!(parse_type(even_more_insane).is_ok()); assert!(parse_type(even_more_insane).is_ok()); }