Skip to content

Commit

Permalink
Add callback-based e-class substitution function
Browse files Browse the repository at this point in the history
  • Loading branch information
marcusrossel committed Apr 4, 2024
1 parent ed90e63 commit 64bfc86
Show file tree
Hide file tree
Showing 5 changed files with 274 additions and 12 deletions.
1 change: 1 addition & 0 deletions Rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ crate-type = ["staticlib"]

[dependencies]
egg = "0.9.5"
indexmap = "1.8.1"
libc = "0.2"
8 changes: 8 additions & 0 deletions Rust/src/analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@ pub struct LeanAnalysisData {
pub loose_bvars: HashSet<u64>, // A bvar is in this set only iff it is referenced by *some* e-node in the e-class.
}

impl LeanAnalysisData {

// TODO: Replace `loose_bvars` with `max_loose_bvar` if eta doesn't require more precision.
pub fn max_loose_bvar(&self) -> Option<u64> {
self.loose_bvars.iter().max().copied()
}
}

#[derive(Default)]
pub struct LeanAnalysis;
impl Analysis<LeanExpr> for LeanAnalysis {
Expand Down
34 changes: 22 additions & 12 deletions Rust/src/lean_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,29 +33,39 @@ define_language! {
}
}

pub fn is_binder(expr: &LeanExpr) -> bool {
match expr {
LeanExpr::Lam(_) | LeanExpr::Forall(_) => true,
_ => false
impl LeanExpr {

pub fn bvar_idx(&self) -> Option<&Id> {
match self {
LeanExpr::BVar(idx) => Some(idx),
_ => None
}
}

// An expression is considered recursive if it can be part of a loop in an e-graph.
// Note that this is a result of the semantics of each constructor, not of its syntactic form.
pub fn is_rec(&self) -> bool {
match self {
LeanExpr::App(_) | LeanExpr::Lam(_) | LeanExpr::Forall(_) => true,
_ => false
}
}
}

// An expression is considered non-recursive if it can never be part of a loop in an e-graph.
// Note that this is a result of the semantics of each constructor, not of its syntactic form.
pub fn is_nonrec(expr: &LeanExpr) -> bool {
pub fn is_binder(expr: &LeanExpr) -> bool {
match expr {
LeanExpr::App(_) | LeanExpr::Lam(_) | LeanExpr::Forall(_) => false,
_ => true
LeanExpr::Lam(_) | LeanExpr::Forall(_) => true,
_ => false
}
}

// An expression `lhs` is smaller than another `rhs` wrt. non-recursiveness if `lhs` is not
// recursive but `rhs` is. If both are either recursive or non-recursive, the total order
// derived by `define_language!` applies.
pub fn nonrec_cmp(lhs: &LeanExpr, rhs: &LeanExpr) -> Ordering {
match (is_nonrec(lhs), is_nonrec(rhs)) {
(true, false) => Ordering::Less,
(false, true) => Ordering::Greater,
match (lhs.is_rec(), rhs.is_rec()) {
(false, true) => Ordering::Less,
(true, false) => Ordering::Greater,
_ => lhs.cmp(rhs),
}
}
Expand Down
1 change: 1 addition & 0 deletions Rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ mod nat_lit;
mod replace_bvars;
mod result;
mod rewrite;
mod subst;
mod trace;
mod util;

Expand Down
242 changes: 242 additions & 0 deletions Rust/src/subst.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
use egg::*;
use std::collections::HashMap;
use std::collections::HashSet;
use indexmap::IndexSet;
use crate::lean_expr::*;
use crate::analysis::*;
use crate::trace::*;

// Binder depth.
type Depth := u64;

// Identifiers of e-classes already present in an e-graph prior to any substitution.
type SrcId := Id;

// Identifiers of substitute e-classes. That is, those e-classes which are created
// during substitution.
type SubId := Id;

// An index of an e-node in `Context.todo`.
type TodoIdx := usize;

// A "target" is an e-class at a specific binder depth. These two values need to be
// considered together so often during substitution that we given them an own type.
#[derive(PartialEq, Eq, Hash, Clone)]
struct Target {
class: SrcId,
depth: Depth
}

// Various data required for bookkeeping during substitution.
#[derive(Default)]
struct Context {
// When an e-class `c` is visited at a binder depth `d` and continues on to visit child e-classes,
// this map records that `c` is work-in-progress for depth `d`. Note that a given e-class can be
// work-in-progress for multiple binder depths at once.
wip: HashMap<SrcId, HashSet<Depth>>,
// The substitution map maps a given e-class target to a new substituted e-class.
sub: HashMap<Target, SubId>,
// The set of unions that need to be performed after the main DFS of subsitution has completed.
// For each `unions.get(c) = Some cs`, the unions need to be performed for all e-classes in `{c} ∪ cs`.
unions: HashMap<SubId, HashSet<SubId>>,
// The set of e-nodes which are waiting for other e-classes to be subsituted before being able to
// be substituted themselves. An e-node appears in this set, when it tried to substitute a child
// e-class which is work-in-progress (see `Context.wip`). That is, when there's a loop in the e-graph.
// The e-node is stored together with the e-class target to which it belongs.
// A todo e-node is then processed when all of its dependencies are registered in `Context.sub`.
// This is in part tracked by `Contet.wait` and `Context.deps`.
todo: IndexSet<(LeanExpr, Target)>,
// The number of child e-classes a given todo e-node is waiting on. When this becomes 0, the e-node
// is processed.
wait: HashMap<TodoIdx, usize>,
// A map from e-class targets to those todo e-nodes waiting on the given target. When a given target
// is registered in `Context.sub`, the `Context.wait` of all dependents is updated.
deps: HashMap<Target, HashSet<TodoIdx>>
}

// Creates a new e-class which is the same as the given `class` but substitutes all loose bound variables
// in its sub-graph according to the given substitution function.
fn subst<B>(class: SrcId, graph: &mut LeanEGraph, reason: Symbol, bvar_subst: &B) -> SubId
where B : Fn(u64, u64, &mut LeanEGraph) -> LeanExpr {
let tgt = Target { class, depth: 0 };
let mut ctx: Context = Default::default();
let s = subst_core(&tgt, &mut ctx, graph, bvar_subst).unwrap();
perform_unions(ctx.unions, reason, graph);
return s
}

// TODO: This might be the exact function in which to control justification propagation.
fn perform_unions(unions: HashMap<SubId, HashSet<SubId>>, reason: Symbol, graph: &mut LeanEGraph) {
for (class, equivs) in unions.iter() {
for other in equivs {
graph.union_trusted(*class, *other, reason);
}
}
}

// When this function returns `None`, that means that a substitution for the given
// e-class target could not yet be created.
fn subst_core<B>(tgt: &Target, ctx: &mut Context, graph: &mut LeanEGraph, bvar_subst: &B) -> Option<SubId>
where B : Fn(u64, u64, &mut LeanEGraph) -> LeanExpr {
if let Some(&s) = ctx.sub.get(tgt) {
// If the given e-class target has already been substituted,
// return the substitute immediately.
return Some(s)
} else if graph[tgt.class].data.max_loose_bvar() < Some(tgt.depth) {
// If the given e-class target does not contain any loose bound variables
// which are large enough to escape the outermost binder of the subsitution's
// root e-class, we can just keep it as is. We do not even insert this information
// into `ctx.sub`, as the only place where this can potentially be queried is in
// the conditional branch above, in which case we just immediately skip to this
// branch again.
return Some(tgt.class)
} else if ctx.wip.get(&tgt.class).is_some_and(|w| w.contains(&tgt.depth)) {
// If the e-class target is already WIP, we have reached a proper loop (one where
// the binder depth does not increase on each iteration) and return `None` to
// indicate this to the caller.
return None
} else {
// If none of the previous branches apply, we are visiting the given e-class target
// for the first time. Thus, we first mark it as WIP.
_ = ctx.wip.entry(tgt.class).or_insert(HashSet::new()).insert(tgt.depth);
subst_core_new_target(tgt, ctx, graph, bvar_subst)
}
}

// Implementation detail of `subst_core`.
fn subst_core_new_target<B>(tgt: &Target, ctx: &mut Context, graph: &mut LeanEGraph, bvar_subst: &B) -> Option<SubId>
where B : Fn(u64, u64, &mut LeanEGraph) -> LeanExpr {
// Gets and sorts the nodes we are going to visit by `nonrec_cmp`. Moving non-recursive
// e-nodes to the front is simply an optimization as this means that we tend to visit
// leaves first which reduces the number of todo e-nodes and corresponding callbacks.
let mut nodes = graph[tgt.class].nodes.clone();
nodes.sort_by(|lhs, rhs| nonrec_cmp(lhs, rhs));

for node in nodes {
if let Some(bvar_idx) = node.bvar_idx() {
let idx_val = graph[*bvar_idx].data.nat_val.unwrap();
let node_sub = bvar_subst(idx_val, tgt.depth, graph);
add_subst_node(node_sub, tgt, ctx, graph);
} else if node.is_rec() {
subst_recursive_node(&node, tgt, ctx, graph, bvar_subst);
} else {
add_subst_node(node, tgt, ctx, graph);
}
}

// If all e-nodes remain todos, this returns `None`, indicating to the caller that this e-class
// target remains WIP. Otherwise, a substitute e-class for `tgt` must have been created and
// registered in `ctx.sub`, which we thus return.
ctx.sub.get(tgt).copied()
}

// Implementation detail of `subst_core_new_target`.
//
// Tries to construct the substitution of a given e-node which is expected to be recursive.
// If this is successful, the substitute is added to the substitute e-class in `ctx.sub`.
// If it fails, the e-node is registered as a todo node.
fn subst_recursive_node<B>(rec_node: &LeanExpr, tgt: &Target, ctx: &mut Context, graph: &mut LeanEGraph, bvar_subst: &B)
where B : Fn(u64, u64, &mut LeanEGraph) -> LeanExpr {
let mut sub_node = rec_node.clone();
let mut pending = HashSet::<Target>::new();

for (idx, child) in sub_node.children_mut().iter_mut().enumerate() {
// The depth is increased by 1 if the child is the body of a binder.
let depth = if idx == 1 && is_binder(rec_node) { tgt.depth + 1 } else { tgt.depth };
let child_tgt = Target { class: *child, depth: depth };
// If the substitution of the child works, replace the child with its substitute in `sub_node`.
// Otherwise, record the given child target as being pending.
if let Some(child_sub) = subst_core(&child_tgt, ctx, graph, bvar_subst) {
*child = child_sub;
} else {
pending.insert(child_tgt);
}
}

// If all children could be substituted, add the then completely substituted `sub_node` as
// a new node for `tgt`. Else, add the pending children as todos for `rec_node`.
if pending.is_empty() {
add_subst_node(sub_node, tgt, ctx, graph);
} else {
add_todo(rec_node, pending, tgt, ctx);
}
}

// Adds the given (substituted) e-node to the substitute e-class of the given e-class target.
// If that substitute e-class does not exist yet, it is created and the todo e-nodes are updated.
fn add_subst_node(node: LeanExpr, tgt: &Target, ctx: &mut Context, graph: &mut LeanEGraph) {
let node_class = graph.add(node);
if let Some(&s) = ctx.sub.get(tgt) {
// If the given e-class target already has a substitute, simply record the e-node's
// class as requiring a union with that substitute e-class.
ctx.unions.entry(s).or_insert(HashSet::new()).insert(node_class);
} else {
// If the given e-node is the first substitute e-node for the given e-class target,
// create the substitute e-class from it.
ctx.sub.insert(tgt.clone(), node_class);
// When a new substitute e-class is created, the todo e-nodes need to be updated.
update_todos(tgt, ctx, graph);
}
}

// Updates the todo e-nodes depending on a given newly substituted e-class target.
// This can might only reduce the e-nodes' waits, or, if the wait becomes 0, lead
// to the e-node being processed.
fn update_todos(tgt: &Target, ctx: &mut Context, graph: &mut LeanEGraph) {
// Get the set of todo e-nodes depending on the given e-class target.
// If this set is empty (absent), nothing needs to be done.
if let Some(deps) = ctx.deps.remove(tgt) {
for dep in deps {
// We assume any given todo e-node to have an associated wait.
let new_wait = ctx.wait.get(&dep).unwrap() - 1;
if new_wait == 0 {
// If the new wait is 0, then all of the todo e-node's children have
// substitute e-classes and the e-node can be processed.
ctx.wait.remove(&dep);
process_todo(dep, ctx, graph);
} else {
// If the new wait is still not 0, the todo e-node must continue
// to wait.
ctx.wait.insert(dep, new_wait);
}
}
}
}

// Performs substitution of the the given todo e-node under the assumption that
// `ctx.sub` contains substitutes for all dependencies.
fn process_todo(todo: TodoIdx, ctx: &mut Context, graph: &mut LeanEGraph) {
// We assume any given todo e-node to have an entry in the `todo` set.
let (node, tgt) = ctx.todo.get_index(todo).unwrap().clone();

// A todo can only be an application-, lambda- or forall-node.
let mut sub_node = node.clone();
for (idx, child) in sub_node.children_mut().iter_mut().enumerate() {
// The depth is increased by 1 if the child is the body of a binder.
let depth = if idx == 1 && is_binder(&node) { tgt.depth + 1 } else { tgt.depth };
let child_tgt = Target { class: *child, depth: depth };
// The substitutes of children of a todo node are expected to be present
// when this function (`process_todo`) is called.
*child = *ctx.sub.get(&child_tgt).unwrap();
}

add_subst_node(sub_node, &tgt, ctx, graph);
}

fn add_todo(node: &LeanExpr, deps: HashSet<Target>, tgt: &Target, ctx: &mut Context) {
// If the given node is already a todo for the given e-class target,
// then nothing more needs to be done.
let (todo, is_new) = ctx.todo.insert_full((node.clone(), tgt.clone()));
if !is_new { return }

// The number of dependencies that need to be waited on is exactly the number of
// elements in `deps`.
ctx.wait.insert(todo, deps.len());

// Add the todo to the dependency list of each of its dependencies. That way, when
// the dependencies are resolved, the todo is processed.
for dep in deps {
ctx.deps.entry(dep).or_insert(HashSet::new()).insert(todo);
}
}

0 comments on commit 64bfc86

Please sign in to comment.