Skip to content

Commit

Permalink
feat: DSL Array Pattern Match & Simplification (#36)
Browse files Browse the repository at this point in the history
# Add List Pattern Matching to the HIR and Evaluation Engine

## Summary

This PR adds list pattern matching capabilities to the language. It
introduces two new pattern variants: `EmptyArray` for matching empty
arrays, and `ArrayDecomp` for decomposing arrays into head and tail
components.

## Changes

- Added new pattern variants to the `Pattern` enum in the HIR (engine) &
AST (parsing).
- Extended the pattern matching logic in `match_pattern` to handle list
patterns
- Added comprehensive tests demonstrating list pattern matching
functionality
- Simplified syntax elements as agreed during the design phase

## Implementation Details

```rust
// New pattern variants
pub enum Pattern {
    // ... existing variants
    EmptyArray,
    ArrayDecomp(Box<Pattern>, Box<Pattern>),
}
```

The pattern matching logic was extended to handle:
- Empty arrays with the `EmptyArray` pattern
- Non-empty arrays with the `ArrayDecomp` pattern, which splits arrays
into head and tail

## Testing

Added a test that shows a practical application of list pattern
matching: a recursive sum function defined as:

```
fn sum(arr: [I64]): I64 =
  match arr 
    | [] -> 0
    \ [head .. tail] -> head + sum(tail)
```

The test verifies that this function correctly computes:
- `sum([]) = 0`
- `sum([42]) = 42`
- `sum([1, 2, 3]) = 6`

## Related Changes

As part of this PR, I also simplified some syntax elements:
- Changed the composition operator from `->` to `.`
- Changed the match arrow from `=>` to `->`
  • Loading branch information
AlSchlo authored Mar 1, 2025
1 parent 2c17bf7 commit 285fe5f
Show file tree
Hide file tree
Showing 14 changed files with 435 additions and 92 deletions.
2 changes: 2 additions & 0 deletions optd-dsl/src/analyzer/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ pub enum Pattern {
Struct(Identifier, Vec<Pattern>),
Operator(Operator<Pattern>),
Wildcard,
EmptyArray,
ArrayDecomp(Box<Pattern>, Box<Pattern>),
}

/// Match arm combining pattern and expression
Expand Down
32 changes: 16 additions & 16 deletions optd-dsl/src/cli/basic.op
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
data LogicalProps(schema_len: I64)

data Scalar with
data Scalar =
| ColumnRef(idx: Int64)
| Literal with
| Literal =
| IntLiteral(value: Int64)
| StringLiteral(value: String)
| BoolLiteral(value: Bool)
\ NullLiteral
| Arithmetic with
| Arithmetic =
| Mult(left: Scalar, right: Scalar)
| Add(left: Scalar, right: Scalar)
| Sub(left: Scalar, right: Scalar)
\ Div(left: Scalar, right: Scalar)
| Predicate with
| Predicate =
| And(children: [Predicate])
| Or(children: [Predicate])
| Not(child: Predicate)
Expand All @@ -24,18 +24,18 @@ data Scalar with
| GreaterThanEqual(left: Scalar, right: Scalar)
| IsNull(expr: Scalar)
\ IsNotNull(expr: Scalar)
| Function with
| Function =
| Cast(expr: Scalar, target_type: String)
| Substring(str: Scalar, start: Scalar, length: Scalar)
\ Concat(args: [Scalar])
\ AggregateExpr with
\ AggregateExpr =
| Sum(expr: Scalar)
| Count(expr: Scalar)
| Min(expr: Scalar)
| Max(expr: Scalar)
\ Avg(expr: Scalar)

data Logical with
data Logical =
| Scan(table_name: String)
| Filter(child: Logical, cond: Predicate)
| Project(child: Logical, exprs: [Scalar])
Expand All @@ -51,11 +51,11 @@ data Logical with
aggregates: [AggregateExpr]
)

data Physical with
data Physical =
| Scan(table_name: String)
| Filter(child: Physical, cond: Predicate)
| Project(child: Physical, exprs: [Scalar])
| Join with
| Join =
| HashJoin(
build_side: Physical,
probe_side: Physical,
Expand Down Expand Up @@ -84,7 +84,7 @@ data Physical with
order_by: [(Scalar, SortOrder)]
)

data JoinType with
data JoinType =
| Inner
| Left
| Right
Expand All @@ -96,19 +96,19 @@ fn (expr: Scalar) apply_children(f: Scalar => Scalar) = ()

fn (pred: Predicate) remap(map: {I64 : I64)}) =
match predicate
| ColumnRef(idx) => ColumnRef(map(idx))
\ _ => predicate -> apply_children(child => rewrite_column_refs(child, map))
| ColumnRef(idx) -> ColumnRef(map(idx))
\ _ -> predicate.apply_children(child -> rewrite_column_refs(child, map))

[rule]
fn (expr: Logical) join_commute = match expr
\ Join(left, right, Inner, cond) ->
\ Join(left, right, Inner, cond) =>
let
right_indices = 0.right.schema_len,
left_indices = 0..left.schema_len,
remapping = left_indices.map(i => (i, i + right_len)) ++
right_indices.map(i => (left_len + i, i)).to_map,
remapping = left_indices.map(i -> (i, i + right_len)) ++
right_indices.map(i -> (left_len + i, i)).to_map,
in
Project(
Join(right, left, Inner, cond.remap(remapping)),
right_indices.map(i => ColumnRef(i)).to_array
right_indices.map(i -> ColumnRef(i)).to_array
)
2 changes: 1 addition & 1 deletion optd-dsl/src/engine/eval/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ pub(super) fn eval_binary_op(left: Value, op: &BinOp, right: Value) -> Value {
}

// Any other combination of value types or operations is not supported
_ => panic!("Invalid binary operation"),
_ => panic!("Invalid binary operation: {:?} {:?} {:?}", left, op, right),
}
}

Expand Down
67 changes: 67 additions & 0 deletions optd-dsl/src/engine/eval/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -665,4 +665,71 @@ mod tests {
assert_eq!(values.len(), 1);
assert!(matches!(&values[0].0, Literal(Int64(30)))); // 10 + 20 = 30 (since 10 < 20)
}

#[test]
fn test_recursive_list_sum() {
let context = Context::new(HashMap::new());

// Define a recursive sum function using pattern matching
// sum([]) = 0
// sum([x .. xs]) = x + sum(xs)
let sum_function = Value(Function(Closure(
vec!["arr".to_string()],
Box::new(PatternMatch(
Box::new(Ref("arr".to_string())),
vec![
// Base case: empty array returns 0
MatchArm {
pattern: Pattern::EmptyArray,
expr: CoreVal(int_val(0)),
},
// Recursive case: add head + sum(tail)
MatchArm {
pattern: Pattern::ArrayDecomp(
Box::new(Bind("head".to_string(), Box::new(Wildcard))),
Box::new(Bind("tail".to_string(), Box::new(Wildcard))),
),
expr: Binary(
Box::new(Ref("head".to_string())),
BinOp::Add,
Box::new(Call(
Box::new(Ref("sum".to_string())),
vec![Ref("tail".to_string())],
)),
),
},
],
)),
)));

// Bind the recursive function in the context
let mut test_context = context.clone();
test_context.bind("sum".to_string(), sum_function);

// Test arrays
let empty_array = Value(CoreData::Array(vec![]));
let array_123 = Value(CoreData::Array(vec![int_val(1), int_val(2), int_val(3)]));
let array_42 = Value(CoreData::Array(vec![int_val(42)]));

// Test 1: Sum of empty array should be 0
let call_empty = Call(Box::new(Ref("sum".to_string())), vec![CoreVal(empty_array)]);

let result = collect_stream_values(call_empty.evaluate(test_context.clone()));
assert_eq!(result.len(), 1);
assert!(matches!(&result[0].0, Literal(Int64(n)) if *n == 0));

// Test 2: Sum of [1, 2, 3] should be 6
let call_123 = Call(Box::new(Ref("sum".to_string())), vec![CoreVal(array_123)]);

let result = collect_stream_values(call_123.evaluate(test_context.clone()));
assert_eq!(result.len(), 1);
assert!(matches!(&result[0].0, Literal(Int64(n)) if *n == 6));

// Test 3: Sum of [42] should be 42
let call_42 = Call(Box::new(Ref("sum".to_string())), vec![CoreVal(array_42)]);

let result = collect_stream_values(call_42.evaluate(test_context));
assert_eq!(result.len(), 1);
assert!(matches!(&result[0].0, Literal(Int64(n)) if *n == 42));
}
}
30 changes: 30 additions & 0 deletions optd-dsl/src/engine/eval/match.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,36 @@ async fn match_pattern(value: Value, pattern: Pattern, context: Context) -> Vec<
_ => vec![],
},

// Empty list pattern: match if value is an empty array
(EmptyArray, CoreData::Array(arr)) if arr.is_empty() => vec![context],

// List decomposition pattern: match first element and rest of the array
(ArrayDecomp(head_pattern, tail_pattern), CoreData::Array(arr)) => {
if arr.is_empty() {
return vec![];
}

// Split array into head and tail
let head = arr[0].clone();
let tail = Value(CoreData::Array(arr[1..].to_vec()));

// Match head against head pattern
let head_contexts = match_pattern(head, (**head_pattern).clone(), context).await;
if head_contexts.is_empty() {
return vec![];
}

// For each successful head match, try to match tail
let mut result_contexts = Vec::new();
for head_ctx in head_contexts {
let tail_contexts =
match_pattern(tail.clone(), (**tail_pattern).clone(), head_ctx).await;
result_contexts.extend(tail_contexts);
}

result_contexts
}

// Struct pattern: match name and recursively match fields
(Struct(pat_name, field_patterns), CoreData::Struct(val_name, field_values)) => {
if pat_name != val_name || field_patterns.len() != field_values.len() {
Expand Down
2 changes: 0 additions & 2 deletions optd-dsl/src/lexer/lex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ fn lexer() -> impl Parser<char, Vec<(Token, Span)>, Error = Simple<char, Span>>
("false", Token::Bool(false)),
("Unit", Token::TUnit),
("data", Token::Data),
("with", Token::With),
("as", Token::As),
("in", Token::In),
("let", Token::Let),
("match", Token::Match),
Expand Down
4 changes: 0 additions & 4 deletions optd-dsl/src/lexer/tokens.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ pub enum Token {
// Other keywords
Fn,
Data,
With,
As,
In,
Let,
Match,
Expand Down Expand Up @@ -85,8 +83,6 @@ impl std::fmt::Display for Token {
// Other keywords
Token::Fn => write!(f, "fn"),
Token::Data => write!(f, "data"),
Token::With => write!(f, "with"),
Token::As => write!(f, "as"),
Token::In => write!(f, "in"),
Token::Let => write!(f, "let"),
Token::Match => write!(f, "match"),
Expand Down
18 changes: 9 additions & 9 deletions optd-dsl/src/parser/adt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub fn adt_parser() -> impl Parser<Token, Spanned<Adt>, Error = Simple<Token, Sp
.map_with_span(Spanned::new);

let with_sum_parser = type_ident
.then_ignore(just(Token::With))
.then_ignore(just(Token::Eq))
.then(
just(Token::Vertical)
.ignore_then(inner_adt_parser.clone())
Expand Down Expand Up @@ -96,7 +96,7 @@ mod tests {

#[test]
fn test_enum_adt() {
let input = "data JoinType with
let input = "data JoinType =
| Inner
\\ Outer";
let (result, errors) = parse_adt(input);
Expand Down Expand Up @@ -132,7 +132,7 @@ mod tests {

#[test]
fn test_enum_with_struct_variants() {
let input = "data Shape with
let input = "data Shape =
| Circle(center: Point, radius: F64)
| Rectangle(topLeft: Point, width: F64, height: F64)
\\ Triangle(p1: Point, p2: Point, p3: Point)";
Expand Down Expand Up @@ -175,8 +175,8 @@ mod tests {

#[test]
fn test_nested_enum() {
let input = "data Expression with
| Literal with
let input = "data Expression =
| Literal =
| IntLiteral(value: I64)
| BoolLiteral(value: Bool)
\\ StringLiteral(value: String)
Expand Down Expand Up @@ -234,15 +234,15 @@ mod tests {

#[test]
fn test_double_nested_enum() {
let input = "data Menu with
| File with
| New with
let input = "data Menu =
| File =
| New =
| Document
| Project
\\ Template
| Open
\\ Save
| Edit with
| Edit =
| Cut
| Copy
\\ Paste
Expand Down
6 changes: 4 additions & 2 deletions optd-dsl/src/parser/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ pub enum Pattern {
Literal(Literal),
/// Wildcard pattern: matches any value
Wildcard,
/// Empty array pattern: matches an empty array
EmptyArray,
/// Array decomposition pattern: matches an array with head and rest elements
ArrayDecomp(Spanned<Pattern>, Spanned<Pattern>),
}

/// Represents a single arm in a pattern match expression
Expand Down Expand Up @@ -212,8 +216,6 @@ pub enum UnaryOp {
pub enum PostfixOp {
/// Function or method call with arguments
Call(Vec<Spanned<Expr>>),
/// Function composition operator
Compose(Identifier),
/// Member/field access
Member(Identifier),
}
Expand Down
Loading

0 comments on commit 285fe5f

Please sign in to comment.