Skip to content

Commit

Permalink
Extend the notion of invalid matches by disallowing matching of local…
Browse files Browse the repository at this point in the history
…ly non-loose bvars
  • Loading branch information
marcusrossel committed Apr 5, 2024
1 parent c11041d commit b2d3e98
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 77 deletions.
2 changes: 1 addition & 1 deletion C/ffi.c
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ extern egg_result c_egg_explain_congr(
// `gen_beta_rw`: boolean indicating whether egg should use an additional rewrite to perform beta-reduction
// `block_invalid_matches`: boolean indicating whether rewrites should be skipped if variables matched bvars in an invalid way
// `shift_captured_bvars`: boolean indicating whether rewrites should shift captured bvars to avoid invalid capturing
// `trace_substitutions`: boolean indicating whether calls to `replace_loose_bvars` should be traced
// `trace_substitutions`: boolean indicating whether calls to `subst` should be traced
// `trace_captured_bvar_shifting`: boolean indicating whether calls to `shifted_subst_for_pat` should be traced
// `viz_path`: string
// return value: string explaining the rewrite sequence
Expand Down
18 changes: 12 additions & 6 deletions Lean/Egg/Tests/BlockInvalidMatches.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,33 @@ import Egg
-- We have to disable β-reduction as part of normalization, as otherwise `thm` is useless.
set_option egg.betaReduceRws false

theorem thm : ∀ y x : Nat, (fun _ => (fun _ => x) x) y = x :=
theorem thm : ∀ y x : Nat, (fun _ => (fun _ => x) x) y = x :=
fun _ _ => rfl

-- In this example egg finds a proof, but we're not performing proof reconstruction (which would be
-- impossible) as a result of setting `exitPoint := some .beforeProof`.
set_option egg.blockInvalidMatches false in
example : False := by
have h : (fun x => (fun a => (fun a => a) a) 0) = (fun x => x) := by
egg (config := { exitPoint := some .beforeProof }) [thm]
egg (config := { exitPoint := some .beforeProof }) [thm]
have : (fun _ => 0) 1 = (fun x => x) 1 := by rw [h]
contradiction

-- This test covers Condition (2) of valid matches.
set_option egg.blockInvalidMatches true in
example : True := by
fail_if_success -- This fails because egg could not find a proof.
have _ : (fun x => (fun a => (fun a => a) a) 0) = (fun x => x) := by
have : (fun x => (fun a => (fun a => a) a) 0) = (fun x => x) := by
egg (config := { exitPoint := some .beforeProof }) [thm]
constructor

-- BUG: The rewrite can be applied backwards in an invalid way as `shiftCapturedBVars` isn't enabled.
-- This theorem is only applicable in the backward direction.
theorem thm₂ : ∀ x y : Nat, x = (fun _ => x) y :=
fun _ _ => rfl

-- This test covers Condition (1) of valid matches.
set_option egg.blockInvalidMatches true in
example : True := by
have _ : (fun x => (fun a => (fun a => a) a) <| nat_lit 0) = (fun x => x) := by
egg (config := { exitPoint := some .beforeProof }) [thm (nat_lit 0)]
fail_if_success -- This fails because egg could not find a proof.
have : (fun x => x) = (fun _ : Nat => (fun x => x) 1) := by egg [thm₂]
constructor
38 changes: 28 additions & 10 deletions Lean/Egg/Tests/ShiftCapturedBVars.lean
Original file line number Diff line number Diff line change
@@ -1,25 +1,43 @@
import Egg

-- We have to disable β-reduction as part of normalization, as otherwise `thm` is useless.
-- We have to disable β-reduction as part of normalization, as otherwise `thm₁,₂` are useless, and
-- disable β-reduction in egg, as this interferes with the test cases.
set_option egg.betaReduceRws false
set_option egg.genBetaRw false

theorem thm : ∀ x : Nat, x = (fun _ => x) (nat_lit 0) :=
fun _ => rfl
-- This theorem is only applicable in the forward direction.
theorem thm₁ : ∀ x y : Nat, (x, y).fst = (fun _ => x) (nat_lit 1) :=
fun _ _ => rfl

-- In this example egg finds a proof, but we're not performing proof reconstruction (which would be
-- impossible) as a result of setting `exitPoint := some .beforeProof`.
set_option egg.shiftCapturedBVars false in
example : False := by
have h : (fun x => x) = (fun y : Nat => (fun x => x) 1) := by
egg (config := { exitPoint := some .beforeProof }) [thm]
have h : (fun x => (x, 5).fst) = (fun _ : Nat => (fun x => x) 1) := by
egg (config := { exitPoint := some .beforeProof }) [thm]
have : (fun x => x) 0 = (fun y => 1) 0 := by rw [h]
contradiction

-- BUG: This generates a completely broken e-graph.
-- For example, a number node becomes equivalent to an application node.
-- This seems to originate entirely from `replace_loose_bvars` though.
set_option egg.shiftCapturedBVars true in
example : True := by
have h : (fun x => x) = (fun y : Nat => (fun x => x) (nat_lit 1)) := by
sorry -- egg (config := { traceCapturedBVarShifting := true }) [thm]
fail_if_success -- This fails because egg could not find a proof.
have : (fun x => (x, 5).fst) = (fun _ : Nat => (fun x => x) 1) := by egg [thm₁]
constructor

theorem thm₂ : ∀ x : Nat, (fun _ => (fun _ => x) x) 0 = x :=
fun _ => rfl

-- TODO: This seems to cause an infinite loop or at least extremely long runtime in
-- `shifted_subst_for_pat` or `subst`. I think what is happening is that `thm₂` is applied in
-- the backward direction over and over again which quickly blows up the e-graph.
-- Investigate further what's happening by somehow tracing `shifted_subst_for_pat`.
set_option egg.shiftCapturedBVars true in
example : True := by
have : (fun x => (fun a => (fun a => a) a) 0) = (fun x => x) := by sorry -- egg [thm₂]
constructor

-- Unrelated to capture avoidance:
--
-- TODO: If we have a theorem like `(fun a b => a) x y = x`, it's only applicable in the forward
-- direction. But once we β-reduce it, it's applicable in both directions. I think that can
-- cause problems during reconstruction as we cannot reconstruct the assignment of `y`.
77 changes: 17 additions & 60 deletions Rust/src/bvar_capture.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use std::str::FromStr;
use egg::*;
use crate::lean_expr::*;
use crate::analysis::*;
use crate::replace_bvars::*;
use crate::trace::*;
use crate::valid_match::*;

pub struct BVarCapture {
pub rhs: Pattern<LeanExpr>,
Expand All @@ -17,73 +17,23 @@ impl Applier<LeanExpr, LeanAnalysis> for BVarCapture {
fn apply_one(&self, egraph: &mut LeanEGraph, _: Id, subst: &Subst, searcher_ast: Option<&PatternAst<LeanExpr>>, rule: Symbol) -> Vec<Id> {
let searcher_ast = searcher_ast.unwrap();

// TODO: Is this cached once it is called?
// A substitution is safe if it does not map any variables to e-classes containing loose bvars.
let subst_is_safe = { |_ : ()|
self.rhs.vars().iter().all(|var| egraph[subst[*var]].data.loose_bvars.is_empty())
};

// Abort the rewrite if invalid matches are disallowed and the given match is invalid.
if self.block_invalid_matches && !subst_is_safe(()) && !match_is_valid(subst, searcher_ast, egraph) {
if self.block_invalid_matches && !match_is_valid(subst, searcher_ast, egraph) {
return vec![]
}

if self.shift_captured_bvars && !subst_is_safe(()) {
// A substitution needs no shifting if it does not map any variables to e-classes containing loose bvars.
let needs_no_shift = self.rhs.vars().iter().all(|var| egraph[subst[*var]].data.loose_bvars.is_empty());
if !self.shift_captured_bvars || needs_no_shift {
// Following https://docs.rs/egg/latest/src/egg/pattern.rs.html#373
let (from, did_union) = egraph.union_instantiations(searcher_ast, &self.rhs.ast, subst, rule);
if did_union { vec![from] } else { vec![] }
} else {
dbg_trace(format!("Start capture avoidance for\n LHS: {}\n RHS: {}\n RHS Raw: {:?}\n subst: {:?}", searcher_ast, self.rhs, self.rhs.ast.as_ref(), subst), TraceGroup::Capture);
let (shifted_subst, shifted_rhs) = shifted_subst_for_pat(subst, &self.rhs, egraph);
dbg_trace("End capture avoidance\n", TraceGroup::Capture);
let (from, did_union) = egraph.union_instantiations(searcher_ast, &shifted_rhs, &shifted_subst, rule);
if did_union { vec![from] } else { vec![] }
} else {
// Following https://docs.rs/egg/latest/src/egg/pattern.rs.html#373
let (from, did_union) = egraph.union_instantiations(searcher_ast, &self.rhs.ast, subst, rule);
if did_union { vec![from] } else { vec![] }
}
}
}

// Checks whether a match (a substitution together with a pattern) is valid. This is the case if
// every variable v of the substitution which maps to an e-class with loose bvars only ever
// appears under one and the same binder.
//
// Example of an invalid match:
// Pattern term `(lam _ (lam _, ?x) ?x)` matching against `(lam _ (lam _, (bvar 0)) (bvar 0))`.
fn match_is_valid(subst: &Subst, pat: &PatternAst<LeanExpr>, egraph: &LeanEGraph) -> bool {
    // The traversal starts at the last node of the pattern, at the empty position and outside
    // of any binder.
    let root_idx = pat.as_ref().len() - 1;
    let mut parent_binders = HashMap::new();
    match_is_valid_aux(root_idx, Vec::new(), None, subst, pat, egraph, &mut parent_binders)
}

type ExprPos = Vec<usize>;
// A binder position of `None` indicates that the associated value does not appear under a binder.
type BinderPos = Option<ExprPos>;

// Recursive worker of `match_is_valid`.
//
// * `idx` is the index of the current node in `pat`.
// * `pos` is the position of the current node relative to the root of the pattern.
// * `parent_binder` is the position of the binder whose body contains the current node (if any).
// * `parent_binders` records, for each visited variable mapping to an e-class with loose bvars,
//   the parent binder of its first occurrence.
fn match_is_valid_aux(
    idx: usize,
    pos: ExprPos,
    parent_binder: BinderPos,
    subst: &Subst,
    pat: &PatternAst<LeanExpr>,
    egraph: &LeanEGraph,
    parent_binders: &mut HashMap<Var, BinderPos>,
) -> bool {
    match &pat.as_ref()[idx] {
        ENodeOrVar::Var(var) => {
            // A variable which does not map to an e-class containing loose bvars cannot cause
            // any problems.
            if egraph[subst[*var]].data.loose_bvars.is_empty() { return true }

            match parent_binders.get(var) {
                // The variable has already occurred elsewhere in the pattern, so the parent
                // binder of that occurrence must coincide with the current parent binder.
                Some(required_parent) => parent_binder == *required_parent,
                // First occurrence of the variable: remember its parent binder.
                None => {
                    parent_binders.insert(*var, parent_binder);
                    true
                }
            }
        },
        ENodeOrVar::ENode(e) => {
            // `all` stops at the first invalid child, just like an early `return false`.
            e.children().iter().enumerate().all(|(i, child)| {
                // The body of a binder (its child at index 1) gets that binder as its parent.
                let binder_of_child =
                    if is_binder(&e) && i == 1 { Some(pos.clone()) } else { parent_binder.clone() };
                let mut pos_of_child = pos.clone();
                pos_of_child.push(i);
                match_is_valid_aux(usize::from(*child), pos_of_child, binder_of_child, subst, pat, egraph, parent_binders)
            })
        }
    }
}
Expand Down Expand Up @@ -161,7 +111,7 @@ fn shifted_subst_for_pat_aux(
let fresh_var = make_fresh_var();
let new_idx = shifted_pat.add(ENodeOrVar::Var(fresh_var));
pat_node_indices.insert(ENodeOrVar::Var(fresh_var), new_idx);
let sub = replace_loose_bvars(&shift_up(binder_depth), target_class, egraph, Symbol::from("λ↕"), &mut ());
let sub = crate::subst::subst(target_class, egraph, Symbol::from("λ↕"), &shift_up(binder_depth));
dbg_trace(format!("var is being replaced by fresh var {} with shifted class {}", fresh_var, sub), TraceGroup::Capture);
subst.insert(fresh_var, sub);
cache.insert((binder_depth, target_class), fresh_var);
Expand Down Expand Up @@ -192,6 +142,13 @@ fn shifted_subst_for_pat_aux(
}
}

// Builds the replacement function handed to `subst`: given the index `idx` of a bvar and the
// binder depth at which it was encountered, it yields a bvar node whose index is `idx + offset`.
// The new index is added to the e-graph as a `Nat` node.
fn shift_up(offset: u64) -> impl Fn(u64, u64, &mut LeanEGraph) -> LeanExpr {
    move |idx, binder_depth, egraph| {
        // `subst` provides the invariant that `idx >= binder_depth`.
        if idx < binder_depth { unreachable!() }
        let shifted_idx = egraph.add(LeanExpr::Nat(idx + offset));
        LeanExpr::BVar(shifted_idx)
    }
}

fn make_fresh_var() -> Var {
use std::sync::atomic::*;
static COUNTER: AtomicUsize = AtomicUsize::new(0);
Expand Down
1 change: 1 addition & 0 deletions Rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ mod rewrite;
mod subst;
mod trace;
mod util;
mod valid_match;

#[repr(C)]
#[derive(PartialEq)]
Expand Down
93 changes: 93 additions & 0 deletions Rust/src/valid_match.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@

use std::collections::HashMap;
use egg::*;
use crate::lean_expr::*;
use crate::analysis::*;

// An expression position is a sequence of coordinates which describe how to traverse nodes
// starting at a given root node. Each coordinate dictates which child of a node to visit.
// In an expression tree, each node has a unique position.
type ExprPos = Vec<usize>;
// A binder position of `None` indicates that the associated value does not appear under a binder.
type BinderPos = Option<ExprPos>;

// The given values are always used together in `match_is_valid_core`, so we give them a type.
struct Location {
    // Position of the current node, relative to the root of the pattern.
    pos: ExprPos,
    // Position of the binder whose body most closely encloses the current node, or `None` if
    // the node does not appear under a binder.
    parent_binder: BinderPos,
    // Number of binder bodies entered on the way from the root to the current node.
    binder_depth: u64,
}

// Various data required for checking match validity. Everything but `parent_binders` is only
// read during the traversal.
struct Context<'a> {
    // The pattern whose match is being checked.
    pat: &'a PatternAst<LeanExpr>,
    // The substitution mapping the pattern's variables to e-classes.
    subst: &'a Subst,
    // The e-graph used to look up the loose bvars of the e-classes in `subst`.
    graph: &'a LeanEGraph,
    // For each visited variable which maps to an e-class with loose bvars: the parent binder of
    // its first occurrence in the pattern.
    parent_binders: HashMap<Var, BinderPos>,
}

// Decides whether a match (a substitution and pattern) is valid. A match is valid if both:
// (1) No (non-loose) bound variables are matched.
// (2) For each variable v in the substitution which maps to an e-class with loose bvars,
//     v only appears under the same binder.
//
// Examples of invalid matches:
// (1) Pattern term `(lam _, ?x)` matching against `(lam _, (bvar 0))`.
// (2) Pattern term `(lam _ (lam _, ?x) ?x)` matching against `(lam _ (lam _, (bvar 2)) (bvar 2))`.
//
// The need for Condition (2) should be obvious. Condition (1) on the other hand only follows from
// the fact that our rewrites come from theorems of the form `forall x, y = z`. Thus, if a pattern
// variable appears, it can never refer to a (non-loose) bound variable, as `x` cannot refer to any
// bound variables in `y` or `z`.
pub fn match_is_valid(subst: &Subst, pat: &PatternAst<LeanExpr>, graph: &LeanEGraph) -> bool {
    // The traversal starts at the last node of the pattern, at the empty position, outside of
    // any binder.
    let root_idx = pat.as_ref().len() - 1;
    let root_loc = Location { pos: Vec::new(), parent_binder: None, binder_depth: 0 };
    let mut ctx = Context { pat, subst, graph, parent_binders: HashMap::new() };
    match_is_valid_core(root_idx, root_loc, &mut ctx)
}

// Recursive worker of `match_is_valid`: checks the sub-pattern rooted at node `idx` of `ctx.pat`,
// where `loc` describes where that node sits in the pattern. First occurrences of variables
// mapping to e-classes with loose bvars are recorded in `ctx.parent_binders`.
fn match_is_valid_core(idx: usize, loc: Location, ctx: &mut Context) -> bool {
    match &ctx.pat.as_ref()[idx] {
        ENodeOrVar::Var(var) => {
            let loose_bvars = &ctx.graph[ctx.subst[*var]].data.loose_bvars;

            // If the given variable does not map to an e-class containing loose bvars, it cannot
            // cause any problems.
            if loose_bvars.is_empty() { return true }

            // Otherwise we need to check Conditions (1) and (2).

            // For Condition (1): If the variable maps to an e-class containing bound variables
            // whose indices are exceeded by the current binder depth, then those bound variables
            // must be non-loose in the context of `ctx.pat`, and are not allowed to be matched.
            if loose_bvars.iter().any(|bvar| *bvar < loc.binder_depth) { return false }

            match ctx.parent_binders.get(var) {
                // For Condition (2): If the variable has already occurred elsewhere in the
                // pattern, then the parent binder of that occurrence must be the same as the
                // current parent binder.
                Some(required_parent) => loc.parent_binder == *required_parent,
                // If the given variable has not been visited yet, record its parent binder.
                None => {
                    ctx.parent_binders.insert(*var, loc.parent_binder);
                    true
                }
            }
        },
        ENodeOrVar::ENode(e) => {
            // `all` stops at the first invalid child, just like an early `return false`.
            e.children().iter().enumerate().all(|(i, child)| {
                // The body of a binder (its child at index 1) gets that binder as its parent
                // and sits one binder deeper; all other children inherit the current values.
                let enters_binder_body = is_binder(&e) && i == 1;
                let parent_binder =
                    if enters_binder_body { Some(loc.pos.clone()) } else { loc.parent_binder.clone() };
                let binder_depth =
                    if enters_binder_body { loc.binder_depth + 1 } else { loc.binder_depth };

                let mut pos = loc.pos.clone();
                pos.push(i);

                let child_loc = Location { pos, parent_binder, binder_depth };
                match_is_valid_core(usize::from(*child), child_loc, ctx)
            })
        }
    }
}

0 comments on commit b2d3e98

Please sign in to comment.