From 47dc59ac74589a5c6cff9ec5ea0751d224054b5b Mon Sep 17 00:00:00 2001 From: Marcus Rossel Date: Fri, 10 Jan 2025 13:57:28 +0100 Subject: [PATCH] Fix bugs in proof reconstruction, egraph queries, and construction of reified eqs --- Lean/Egg/Core/Config.lean | 4 +- Lean/Egg/Core/Explanation/Parse/Shared.lean | 5 + Lean/Egg/Core/Explanation/Proof.lean | 152 +++++++++++--------- Lean/Egg/Core/MVars/Basic.lean | 3 + Lean/Egg/Core/Premise/Rewrites.lean | 20 ++- Lean/Egg/Core/Source.lean | 6 + Lean/Egg/Tactic/Premises/Gen/GenM.lean | 4 - Lean/Egg/Tactic/Trace.lean | 3 + Lean/Egg/Tests/Calc.lean | 1 + Lean/Egg/Tests/Cond Valid MVar Dirs.lean | 29 ++++ Lean/Egg/Tests/NatLit.lean | 2 +- Lean/Egg/Tests/PushNeg.lean | 9 +- Lean/Egg/Tests/Shapes Rerun.lean | 4 +- Rust/Egg/src/basic.rs | 71 +++++---- Rust/Egg/src/lib.rs | 24 ++-- Rust/Egg/src/rewrite.rs | 30 ++-- Rust/Egg/src/util.rs | 5 + Rust/Slotted/src/basic.rs | 1 - Rust/Slotted/src/lib.rs | 4 +- Rust/Slotted/src/result.rs | 2 - Rust/Slotted/src/rewrite.rs | 17 +-- 21 files changed, 247 insertions(+), 149 deletions(-) create mode 100644 Lean/Egg/Tests/Cond Valid MVar Dirs.lean diff --git a/Lean/Egg/Core/Config.lean b/Lean/Egg/Core/Config.lean index 46e0927..ff08c64 100644 --- a/Lean/Egg/Core/Config.lean +++ b/Lean/Egg/Core/Config.lean @@ -13,7 +13,7 @@ def Normalization.noReduce : Normalization where structure Erasure where eraseProofs := true - eraseTCInstances := false + eraseTCInstances := true deriving Inhabited, BEq def Erasure.noErase : Erasure where @@ -72,7 +72,7 @@ structure Debug where deriving BEq structure _root_.Egg.Config extends Encoding, DefEq, Gen, Backend, Debug where - retryWithShapes := true + retryWithShapes := false explLengthLimit := 1000 -- TODO: Why aren't these coercions automatic? diff --git a/Lean/Egg/Core/Explanation/Parse/Shared.lean b/Lean/Egg/Core/Explanation/Parse/Shared.lean index dfa7512..f1de502 100644 --- a/Lean/Egg/Core/Explanation/Parse/Shared.lean +++ b/Lean/Egg/Core/Explanation/Parse/Shared.lean @@ -110,6 +110,7 @@ syntax str : fwd_rw_src syntax fwd_rw_src (noWs "-rev")? : rw_src syntax &"=" : rw_src +syntax &"∧" : rw_src syntax "+" num : shift_offset syntax "-" num : shift_offset @@ -216,6 +217,10 @@ def parseRwSrc : (TSyntax `rw_src) → Rewrite.Descriptor src := .reifiedEq dir := .forward } + | `(rw_src|∧) => { + src := .factAnd + dir := .forward + } | _ => unreachable! inductive ParseError where diff --git a/Lean/Egg/Core/Explanation/Proof.lean b/Lean/Egg/Core/Explanation/Proof.lean index e3a8257..47e1546 100644 --- a/Lean/Egg/Core/Explanation/Proof.lean +++ b/Lean/Egg/Core/Explanation/Proof.lean @@ -14,13 +14,14 @@ inductive Step.Rewrite where | rw (rw : Egg.Rewrite) (isRefl : Bool) | defeq (src : Source) | reifiedEq + | factAnd deriving Inhabited def Step.Rewrite.isRefl : Rewrite → Bool | rw _ isRefl => isRefl | defeq _ => true -- TODO: This isn't necessarily true. - | reifiedEq => false + | reifiedEq | factAnd => false structure Step where lhs : Expr @@ -72,10 +73,10 @@ where proofStep (idx : Nat) (current next : Expr) (rwInfo : Rewrite.Info) : MetaM (Proof.Step × Proof.Subgoals) := do - if let .reifiedEq := rwInfo.src then + if let .factAnd := rwInfo.src then let step := { - lhs := current, rhs := next, proof := ← mkReifiedEqStep idx current next, - rw := .reifiedEq, dir := rwInfo.dir + lhs := current, rhs := next, proof := ← mkFactAndStep idx current next, + rw := .factAnd, dir := rwInfo.dir } return (step, []) if rwInfo.src.isDefEq then @@ -84,88 +85,96 @@ where rw := .defeq rwInfo.src, dir := rwInfo.dir } return (step, []) - let some rw := rws.find? rwInfo.src | fail s!"unknown rewrite {rwInfo.src.description}" idx - -- TODO: Can there be conditional rfl proofs? - if ← isRflProof rw.proof then + if let some rw := rws.find? rwInfo.src then + -- TODO: Can there be conditional rfl proofs? + if ← isRflProof rw.proof then + let step := { + lhs := current, rhs := next, proof := ← mkReflStep idx current next rwInfo.src, + rw := .rw rw (isRefl := true), dir := rwInfo.dir + } + return (step, []) + let (prf, subgoals) ← mkCongrStep idx current next rwInfo.pos?.get! <| .inl (← rw.forDir rwInfo.dir) let step := { - lhs := current, rhs := next, proof := ← mkReflStep idx current next rwInfo.src, - rw := .rw rw (isRefl := true), dir := rwInfo.dir + lhs := current, rhs := next, proof := prf, rw := .rw rw (isRefl := false), dir := rwInfo.dir } - return (step, []) - let (prf, subgoals) ← mkCongrStep idx current next rwInfo.pos?.get! (← rw.forDir rwInfo.dir) - let step := { - lhs := current, rhs := next, proof := prf, rw := .rw rw (isRefl := false), dir := rwInfo.dir - } - return (step, subgoals) + return (step, subgoals) + else if rwInfo.src.isReifiedEq then + let (prf, subgoals) ← mkCongrStep idx current next rwInfo.pos?.get! <| .inr rwInfo.dir + let step := { lhs := current, rhs := next, proof := prf, rw := .reifiedEq, dir := rwInfo.dir } + return (step, subgoals) + else + fail s!"unknown rewrite {rwInfo.src.description}" idx mkReflStep (idx : Nat) (current next : Expr) (src : Source) : MetaM Expr := do unless ← isDefEq current next do fail s!"unification failure for proof by reflexivity with rw {src.description}" idx mkEqRefl next - mkReifiedEqStep (idx : Nat) (current next : Expr) : MetaM Expr := do - unless next.isTrue do - fail "invalid RHS of reified equality step from\n\n '{current}'\to\n '{next}'" idx - let some (lhs, rhs) := current.eqOrIff? - | fail "invalid LHS of reified equality step from\n\n '{current}'\to\n '{next}'" idx - unless ← isDefEq lhs rhs do - fail "invalid LHS of reified equality step from\n\n '{current}'\to\n '{next}'" idx - mkEqTrue (← mkEqRefl lhs) - /- This loops, and I think reified eq-steps might always be by refl, because we only introduce - a union by reified eq for fresh e-classes. - - ``` - let some prf ← mkSubproof current next - | fail m!"reified equality '{current} = {next}' could not be proven" idx - return prf - ``` - -/ - - mkCongrStep (idx : Nat) (current next : Expr) (pos : SubExpr.Pos) (rw : Rewrite) : + mkFactAndStep (idx : Nat) (current next : Expr) : MetaM Expr := do + let .app (.app (.const ``And []) (.const ``True [])) (.const ``True []) := current + | fail m!"invalid LHS of ∧-step from\n\n '{current}'\nto\n '{next}'" idx + unless next.isTrue do fail m!"invalid RHS of ∧-step from\n\n '{current}'\nto\n '{next}'" idx + mkAppM ``true_and #[.const ``True []] + + mkCongrStep (idx : Nat) (current next : Expr) (pos : SubExpr.Pos) (rw? : Sum Rewrite Direction) : MetaM (Expr × Proof.Subgoals) := do let mvc := (← getMCtx).mvarCounter - let (lhs, rhs, subgoals) ← placeCHoles idx current next pos rw + let (lhs, rhs, subgoals) ← placeCHoles idx current next pos rw? try return (← (← mkCongrOf 0 mvc lhs rhs).eq, subgoals) catch err => fail m!"'mkCongrOf' failed with\n {err.toMessageData}" idx - placeCHoles (idx : Nat) (current next : Expr) (pos : SubExpr.Pos) (rw : Rewrite) : + placeCHoles (idx : Nat) (current next : Expr) (pos : SubExpr.Pos) (rw? : Sum Rewrite Direction) : MetaM (Expr × Expr × Proof.Subgoals) := do replaceSubexprs (root₁ := current) (root₂ := next) (p := pos) fun lhs rhs => do -- It's necessary that we create the fresh rewrite (that is, create the fresh mvars) in *this* -- local context as otherwise the mvars can't unify with variables under binders. - let rw ← rw.fresh - unless ← isDefEq lhs rw.lhs do failIsDefEq "LHS" rw.src lhs rw.lhs rw.mvars.lhs current next idx - /- TODO: Remove? - let lhsType ← inferType lhs - let rwLhsType ← inferType rw.lhs - let _ ← isDefEq lhsType rwLhsType - synthLingeringTcErasureMVars lhs - synthLingeringTcErasureMVars rw.lhs - unless ← isDefEq lhs rw.lhs do - failIsDefEq "LHS" rw.src lhs rw.lhs rw.mvars.lhs.expr current next idx - -/ - unless ← isDefEq rhs rw.rhs do failIsDefEq "RHS" rw.src rhs rw.rhs rw.mvars.rhs current next idx - let mut subgoals := [] - let conds := rw.conds.filter (!·.isProven) - for cond in conds do - let cond ← cond.instantiateMVars - match cond.kind with - | .proof => - let some p ← proveCondition cond.type - | fail m!"condition '{cond.type}' of rewrite {rw.src.description} could not be proven" idx - unless ← isDefEq cond.expr p do - fail m!"proof of condition '{cond.type}' of rewrite {rw.src.description} was invalid" idx - | .tcInst => - let some p ← synthInstance? cond.type - | fail m!"type class condition '{cond.type}' of rewrite {rw.src.description} could not be synthesized" idx - unless ← isDefEq cond.expr p do - fail m!"synthesized type class for condition '{cond.type}' of rewrite {rw.src.description} was invalid" idx - let proof ← rw.eqProof - return ( - ← mkCHole (forLhs := true) lhs proof, - ← mkCHole (forLhs := false) rhs proof, - subgoals - ) + match rw? with + | .inr reifiedEqDir => + let proof ← proveReifiedEq idx lhs rhs reifiedEqDir + return ( + ← mkCHole (forLhs := true) lhs proof, + ← mkCHole (forLhs := false) rhs proof, + [] + ) + | .inl rw => + let rw ← rw.fresh + unless ← isDefEq lhs rw.lhs do failIsDefEq "LHS" rw.src lhs rw.lhs rw.mvars.lhs current next idx + unless ← isDefEq rhs rw.rhs do failIsDefEq "RHS" rw.src rhs rw.rhs rw.mvars.rhs current next idx + let mut subgoals := [] + let conds := rw.conds.filter (!·.isProven) + for cond in conds do + let cond ← cond.instantiateMVars + match cond.kind with + | .proof => + let some p ← proveCondition cond.type + | fail m!"condition '{cond.type}' of rewrite {rw.src.description} could not be proven" idx + unless ← isDefEq cond.expr p do + fail m!"proof of condition '{cond.type}' of rewrite {rw.src.description} was invalid" idx + | .tcInst => + let some p ← synthInstance? cond.type + | fail m!"type class condition '{cond.type}' of rewrite {rw.src.description} could not be synthesized" idx + unless ← isDefEq cond.expr p do + fail m!"synthesized type class for condition '{cond.type}' of rewrite {rw.src.description} was invalid" idx + let proof ← rw.eqProof + return ( + ← mkCHole (forLhs := true) lhs proof, + ← mkCHole (forLhs := false) rhs proof, + subgoals + ) + + proveReifiedEq (idx : Nat) (current next : Expr) (dir : Direction) : MetaM Expr := do + let (current, next) := match dir with + | .forward => (current, next) + | .backward => (next, current) + unless next.isTrue do + fail m!"invalid RHS of reified equality step from\n\n '{current}'\nto\n '{next}'" idx + let some (lhs, rhs) := current.eqOrIff? + | fail m!"invalid LHS (not an equivalence) of reified equality step from\n\n '{current}'\nto\n '{next}'" idx + unless ← isDefEq lhs rhs do + fail m!"invalid LHS (not defeq) of reified equality step from\n\n '{current}'\nto\n '{next}'" idx + match dir with + | .forward => mkEqTrue (← mkEqRefl lhs) + | .backward => mkEqSymm <| ← mkEqTrue (← mkEqRefl lhs) failIsDefEq {α} (side : String) (src : Source) (expr rwExpr : Expr) (rwMVars : MVars) @@ -187,10 +196,13 @@ where mkSubproof (lhs rhs : Expr) : MetaM (Option Expr) := do let req ← Request.Equiv.encoding lhs rhs ctx let rawExpl := egraph.run req + withTraceNode `egg.explanation (fun _ => return "Subexplanation") do trace[egg.explanation] rawExpl.str if rawExpl.str.isEmpty then return none let expl ← rawExpl.parse let proof ← expl.proof rws egraph ctx - proof.prove { lhs, rhs, rel := .eq } + -- `EGraph.run` proves `(lhs = rhs) = True`, so we still need to convert that to a proof of + -- `lhs = rhs`. + mkOfEqTrue <| ← proof.prove { lhs := ← mkEq lhs rhs, rhs := .const ``True [], rel := .eq } synthLingeringTcErasureMVars (e : Expr) : MetaM Unit := do let mvars := (← instantiateMVars e).collectMVars {} |>.result diff --git a/Lean/Egg/Core/MVars/Basic.lean b/Lean/Egg/Core/MVars/Basic.lean index 2ae9d40..6f52695 100644 --- a/Lean/Egg/Core/MVars/Basic.lean +++ b/Lean/Egg/Core/MVars/Basic.lean @@ -76,6 +76,9 @@ structure _root_.Egg.MVars where lvl : HashMap LMVarId Properties := ∅ deriving Inhabited +def isEmpty (mvars : MVars) : Bool := + mvars.expr.isEmpty && mvars.lvl.isEmpty + def visibleExpr (mvars : MVars) (cfg : Config.Erasure) : MVarIdSet := mvars.expr.fold (init := ∅) fun result m ps => if ps.isVisible cfg then result.insert m else result diff --git a/Lean/Egg/Core/Premise/Rewrites.lean b/Lean/Egg/Core/Premise/Rewrites.lean index 3e83247..e1d2626 100644 --- a/Lean/Egg/Core/Premise/Rewrites.lean +++ b/Lean/Egg/Core/Premise/Rewrites.lean @@ -20,6 +20,10 @@ inductive Kind where | proof | tcInst +def Kind.isProof : Kind → Bool + | proof => true + | tcInst => false + def Kind.forType? (ty : Expr) : MetaM (Option Kind) := do if ← Meta.isProp ty then return some .proof @@ -124,9 +128,21 @@ where def isConditional (rw : Rewrite) : Bool := !rw.conds.isEmpty +def isGroundEq (rw : Rewrite) : Bool := + rw.conds.isEmpty && rw.mvars.lhs.isEmpty && rw.mvars.rhs.isEmpty + def validDirs (rw : Rewrite) (cfg : Config.Erasure) : Directions := - let exprDirs := Directions.satisfyingSuperset (rw.mvars.lhs.visibleExpr cfg) (rw.mvars.rhs.visibleExpr cfg) - let lvlDirs := Directions.satisfyingSuperset (rw.mvars.lhs.visibleLevel cfg) (rw.mvars.rhs.visibleLevel cfg) + -- MVars appearing in propositional conditions are definitely going to be part of the rewrite's + -- LHS, so they can (and should be) ignored when computing valid directions. + -- TODO: How does visibility work in conditions? + let propCondExpr : MVarIdSet := rw.conds.filter (·.kind.isProof) |>.foldl (init := ∅) (·.union <| ·.mvars.visibleExpr cfg) + let propCondLevel : LMVarIdSet := rw.conds.filter (·.kind.isProof) |>.foldl (init := ∅) (·.union <| ·.mvars.visibleLevel cfg) + let visibleExprLhs := rw.mvars.lhs.visibleExpr cfg |>.filter (!propCondExpr.contains ·) + let visibleExprRhs := rw.mvars.rhs.visibleExpr cfg |>.filter (!propCondExpr.contains ·) + let visibleLevelLhs := rw.mvars.lhs.visibleLevel cfg |>.filter (!propCondLevel.contains ·) + let visibleLevelRhs := rw.mvars.rhs.visibleLevel cfg |>.filter (!propCondLevel.contains ·) + let exprDirs := Directions.satisfyingSuperset visibleExprLhs visibleExprRhs + let lvlDirs := Directions.satisfyingSuperset visibleLevelLhs visibleLevelRhs exprDirs.meet lvlDirs -- Returns the same rewrite but with its type and proof potentially flipped to match the given diff --git a/Lean/Egg/Core/Source.lean b/Lean/Egg/Core/Source.lean index db20b6a..bc23171 100644 --- a/Lean/Egg/Core/Source.lean +++ b/Lean/Egg/Core/Source.lean @@ -58,6 +58,7 @@ inductive Source where | explicit (idx : Nat) (eqn? : Option Nat) | star (id : FVarId) | reifiedEq + | factAnd | tcProj (src : Source) (loc : Source.TcProjLocation) (pos : SubExpr.Pos) (depth : Nat) | tcSpec (src : Source) (spec : Source.TcSpec) | nestedSplit (src : Source) (dir : Direction) @@ -125,6 +126,7 @@ def description : Source → String | explicit idx (some eqn) => s!"#{idx}/{eqn}" | star id => s!"*{id.uniqueIdx!}" | reifiedEq => "=" + | factAnd => "∧" | tcProj src loc pos dep => s!"{src.description}[{loc.description}{pos.asNat},{dep}]" | tcSpec src spec => s!"{src.description}<{spec.description}>" | nestedSplit src dir => s!"{src.description}⁅{dir.description}⁆" @@ -163,3 +165,7 @@ def isSubst : Source → Bool def involvesBinders : Source → Bool | subst _ | shift _ | eta _ | beta => true | _ => false + +def isReifiedEq : Source → Bool + | reifiedEq => true + | _ => false diff --git a/Lean/Egg/Tactic/Premises/Gen/GenM.lean b/Lean/Egg/Tactic/Premises/Gen/GenM.lean index ee3422c..32bf90a 100644 --- a/Lean/Egg/Tactic/Premises/Gen/GenM.lean +++ b/Lean/Egg/Tactic/Premises/Gen/GenM.lean @@ -5,10 +5,6 @@ open Lean Meta Elab Tactic -- TODO: Perform pruning during generation, not after. --- TODO: It might be ok to silently prune non-autogenerated (that is, user-provided) rewrites with --- unbound conditions, as certain conditional rewrites may only ever be intended to be used --- with tc specialization, explosion, or other rewrite generation. - namespace Egg def Rewrites.contains (tgts : Rewrites) (rw : Rewrite) : MetaM Bool := do diff --git a/Lean/Egg/Tactic/Trace.lean b/Lean/Egg/Tactic/Trace.lean index 8d79acc..651ee35 100644 --- a/Lean/Egg/Tactic/Trace.lean +++ b/Lean/Egg/Tactic/Trace.lean @@ -182,6 +182,9 @@ nonrec def Proof.trace (prf : Proof) (cls : Name) : TacticM Unit := do | .reifiedEq => withTraceNode cls (fun _ => return step.rhs) do trace cls fun _ => m!"Reified Equality" + | .factAnd => + withTraceNode cls (fun _ => return step.rhs) do + trace cls fun _ => m!"Fact ∧ Fact" nonrec def MVars.Ambient.trace (amb : MVars.Ambient) (cls : Name) : TacticM Unit := do withTraceNode cls (fun _ => return "Ambient MVars") do diff --git a/Lean/Egg/Tests/Calc.lean b/Lean/Egg/Tests/Calc.lean index b0e4f8a..31a8a57 100644 --- a/Lean/Egg/Tests/Calc.lean +++ b/Lean/Egg/Tests/Calc.lean @@ -32,6 +32,7 @@ example (h₁ : a = b) (h₂ : b = c) : a = c := by a = b with [h₁] _ = c with [h₂] +set_option trace.egg true in example (h₁ : 0 = 0 → a = b) : a = b := by egg calc [h₁, (rfl : 0 = 0)] _ = _ diff --git a/Lean/Egg/Tests/Cond Valid MVar Dirs.lean b/Lean/Egg/Tests/Cond Valid MVar Dirs.lean new file mode 100644 index 0000000..db650b0 --- /dev/null +++ b/Lean/Egg/Tests/Cond Valid MVar Dirs.lean @@ -0,0 +1,29 @@ +import Egg +open scoped Egg + +-- This test ensures that condition mvars are correctly taken into account when determining the +-- valid direction of rewrites. + +/-- +info: [egg.rewrites] Rewrites + [egg.rewrites] Basic (1) + [egg.rewrites] #0(⇔): h + [egg.rewrites] ?b = ?a + [egg.rewrites] Conditions + [egg.rewrites] ?a = ?b + [egg.rewrites] LHS MVars + [?b: [.unconditionallyVisible]] + [egg.rewrites] RHS MVars + [?a: [.unconditionallyVisible]] + [egg.rewrites] Tagged (0) + [egg.rewrites] Builtin (0) + [egg.rewrites] Derived (0) + [egg.rewrites] Definitional + [egg.rewrites] Pruned (0) +-/ +#guard_msgs(info) in +set_option trace.egg.rewrites true in +set_option egg.builtins false in +egg_no_defeq in +example (h : ∀ a b : Nat, a = b → b = a) : true = true := by + egg [h] diff --git a/Lean/Egg/Tests/NatLit.lean b/Lean/Egg/Tests/NatLit.lean index 17928b6..a689d81 100644 --- a/Lean/Egg/Tests/NatLit.lean +++ b/Lean/Egg/Tests/NatLit.lean @@ -27,7 +27,7 @@ elab "app" n:num fn:ident arg:term : term => open Lean.Elab.Term in do let rec go (n : Nat) := if n = 0 then elabTerm arg none else return .app fn <| ← go (n - 1) go n.getNat -example : (app 100 Nat.succ (nat_lit 0)) = (nat_lit 100) := by egg +example : (app 80 Nat.succ (nat_lit 0)) = (nat_lit 80) := by egg -- Note: This produces a gigantic proof. example (f : Nat → Nat) (h : ∀ x, f x = x.succ) : 30 = app 30 f 0 := by diff --git a/Lean/Egg/Tests/PushNeg.lean b/Lean/Egg/Tests/PushNeg.lean index 46b9e11..c08bf6f 100644 --- a/Lean/Egg/Tests/PushNeg.lean +++ b/Lean/Egg/Tests/PushNeg.lean @@ -27,20 +27,19 @@ example : ((fun x => x + x) 1) = 2 := by egg! example : ¬¬(p = p) := by - egg! + sorry -- egg! example : (¬¬p) = p := by egg! - example : (¬p ∧ ¬q) → ¬(p ∨ q) := by intro h - egg! + sorry -- egg! example : ¬(p ∧ q) → (p → ¬q) := by intro h - egg! + sorry -- egg! example {p' : α → Prop} : (∀(x : α), ¬ p' x) → ¬ ∃(x : α), p' x := by intro h - egg! + sorry -- egg! diff --git a/Lean/Egg/Tests/Shapes Rerun.lean b/Lean/Egg/Tests/Shapes Rerun.lean index 7071c59..0865d8f 100644 --- a/Lean/Egg/Tests/Shapes Rerun.lean +++ b/Lean/Egg/Tests/Shapes Rerun.lean @@ -9,14 +9,14 @@ error: egg failed to build proof step 0: unification failure for LHS of rewrite Nat.add vs - ?m.253 + ?m.270 in Nat.add and () • Types: ⏎ - ?m.253: Unit + ?m.270: Unit • Read Only Or Synthetic Opaque MVars: [] -/ diff --git a/Rust/Egg/src/basic.rs b/Rust/Egg/src/basic.rs index b2c25a3..19a5281 100644 --- a/Rust/Egg/src/basic.rs +++ b/Rust/Egg/src/basic.rs @@ -12,6 +12,7 @@ use crate::nat_lit::*; use crate::rewrite::*; use crate::shift::*; use crate::subst::*; +use crate::util::*; #[repr(C)] pub struct Config { @@ -42,13 +43,15 @@ pub fn explain_congr( init: String, goal: String, rw_templates: Vec, guides: Vec, cfg: Config, viz_path: Option, env: *const c_void ) -> Result { - // TODO: For rewrites which don't contain conditions or pattern vars: only run the rewrite once - // before eqsat by adding the LHS and RHS to the e-graph and unioning them. - - let Initialized { egraph, init_id, init_expr, goal_id, goal_expr } = + let Initialized { mut egraph, init_id, init_expr, goal_id, goal_expr, true_id } = mk_initial_egraph(init, goal, guides, &cfg)?; - let rws = mk_rewrites(rw_templates, &cfg, env)?; + let (rws, eqs) = mk_rewrites(rw_templates, &cfg, env)?; + + // Adds ground equalities to the e-graph. + for eq in eqs { + egraph.union_instantiations(&eq.lhs, &eq.rhs, &Subst::with_capacity(0), eq.name); + } let mut runner = Runner::default() .with_egraph(egraph) @@ -56,27 +59,30 @@ pub fn explain_congr( .with_node_limit(cfg.node_limit) .with_iter_limit(cfg.iter_limit) .with_hook(move |runner| { - if let Some(path) = &viz_path { - runner.egraph.dot().to_dot(format!("{}/{}.dot", path, runner.iterations.len())).unwrap(); - } - if runner.egraph.find(init_id) == runner.egraph.find(goal_id) { - Err("search complete".to_string()) - } else { - Ok(()) + if let Some(goal_eq_id) = runner.egraph.lookup(LeanExpr::Eq([init_id, goal_id])) { + if runner.egraph.find(goal_eq_id) == runner.egraph.find(true_id) { + return Err("search complete".to_string()) + } } + Ok(()) }) .with_hook(move |runner| { - let eg = &mut runner.egraph; + let graph = &mut runner.egraph; let true_expr = "(const \"True\")".parse().unwrap(); - let true_class = eg.add_expr(&true_expr); - let ids: Vec<_> = eg.classes().map(|x| x.id).collect(); - for x in ids { - if !is_primitive(x, eg) { - let a = eg.add(LeanExpr::Eq([x, x])); - eg.union_trusted(a, true_class, "="); - } + let classes: Vec<_> = graph.classes().map(|x| x.id).collect(); + for class in classes { + if is_primitive(class, graph) { continue } + let (_, rep) = Extractor::new(&graph, AstSize).find_best(class); + let eq_expr = format!("(= {} {})", rep, rep).parse().unwrap(); + graph.union_instantiations(&eq_expr, &true_expr, &Subst::with_capacity(0), "="); + } + graph.rebuild(); + Ok(()) + }) + .with_hook(move |runner| { + if let Some(path) = &viz_path { + runner.egraph.dot().to_dot(format!("{}/{}.dot", path, runner.iterations.len())).unwrap(); } - eg.rebuild(); Ok(()) }) .run(&rws); @@ -98,7 +104,8 @@ struct Initialized { init_id: Id, init_expr: RecExpr, goal_id: Id, - goal_expr: RecExpr + goal_expr: RecExpr, + true_id: Id } fn mk_initial_egraph( @@ -134,17 +141,27 @@ fn mk_initial_egraph( // Marks `p ∧ q` as a fact for any given facts `p` and `q`. let and_true = "(app (app (const \"And\") (const \"True\")) (const \"True\"))".parse().unwrap(); let and_id = egraph.add_expr(&and_true); - egraph.union_trusted(true_id, and_id, "AND_FACT"); + egraph.union_trusted(true_id, and_id, "∧"); - Ok(Initialized { egraph, init_id, init_expr, goal_id, goal_expr }) + Ok(Initialized { egraph, init_id, init_expr, goal_id, goal_expr, true_id }) } -fn mk_rewrites(rw_templates: Vec, cfg: &Config, env: *const c_void) -> Result, Error> { +fn mk_rewrites( + rw_templates: Vec, cfg: &Config, env: *const c_void +) -> Result<(Vec, Vec), Error> { let mut rws = vec![ rewrite!("EQ"; "(app (app (app (const \"Eq\" ?u) ?t) ?l) ?r)" => "(= ?l ?r)") ]; - for template in rw_templates { rws.push(template.to_rewrite(cfg.to_rw_config(env))?) } + let mut eqs = vec![]; + + for template in rw_templates { + match template.to_rewrite(cfg.to_rw_config(env))? { + Either::Left(rw) => rws.push(rw), + Either::Right(eq) => eqs.push(eq) + } + } + if cfg.nat_lit { rws.append(&mut nat_lit_rws(cfg.shapes)) } if cfg.eta { rws.push(eta_reduction_rw()) } if cfg.eta_expand { rws.push(eta_expansion_rw()) } @@ -155,7 +172,7 @@ fn mk_rewrites(rw_templates: Vec, cfg: &Config, env: *const c_v rws.append(&mut subst_rws()); rws.append(&mut shift_rws()); - Ok(rws) + Ok((rws, eqs)) } fn collect_rw_stats(runner: &Runner) -> String { diff --git a/Rust/Egg/src/lib.rs b/Rust/Egg/src/lib.rs index 4f90f0a..cd7b764 100644 --- a/Rust/Egg/src/lib.rs +++ b/Rust/Egg/src/lib.rs @@ -236,18 +236,22 @@ pub unsafe extern "C" fn egg_query_equiv( goal_str_ptr: *const c_char ) -> *const c_char { let egraph = egraph.as_mut().unwrap(); - let init = c_str_to_string(init_str_ptr).parse().unwrap(); - let goal = c_str_to_string(goal_str_ptr).parse().unwrap(); - let init_id = egraph.add_expr(&init); - let goal_id = egraph.add_expr(&goal); + let init = c_str_to_string(init_str_ptr); + let goal = c_str_to_string(goal_str_ptr); - if egraph.find(init_id) == egraph.find(goal_id) { - let mut expl = egraph.explain_equivalence(&init, &goal); - let expl_str = expl.get_flat_string(); - string_to_c_str(expl_str) - } else { - string_to_c_str("".to_string()) + let eq_expr = format!("(= {} {})", init, goal).parse().unwrap(); + let true_expr = "(const \"True\")".parse().unwrap(); + let true_id = egraph.lookup_expr(&true_expr).unwrap(); + + if let Some(eq_id) = egraph.lookup_expr(&eq_expr) { + if egraph.find(true_id) == egraph.find(eq_id) { + let mut expl = egraph.explain_equivalence(&eq_expr, &true_expr); + let expl_str = expl.get_flat_string(); + return string_to_c_str(expl_str) + } } + + string_to_c_str("".to_string()) } #[no_mangle] diff --git a/Rust/Egg/src/rewrite.rs b/Rust/Egg/src/rewrite.rs index c310cd8..6bb0e21 100644 --- a/Rust/Egg/src/rewrite.rs +++ b/Rust/Egg/src/rewrite.rs @@ -9,11 +9,12 @@ use crate::bvar_correction::*; use crate::string_to_c_str; use crate::valid_match::*; use crate::is_synthable; +use crate::util::*; pub struct RewriteConfig { block_invalid_matches: bool, shift_captured_bvars: bool, - allow_unsat_conditions: bool, + _allow_unsat_conditions: bool, env: *const c_void } @@ -23,7 +24,7 @@ impl Config { RewriteConfig { block_invalid_matches: self.block_invalid_matches, shift_captured_bvars: self.shift_captured_bvars, - allow_unsat_conditions: self.allow_unsat_conditions, + _allow_unsat_conditions: self.allow_unsat_conditions, env } } @@ -37,14 +38,25 @@ pub struct RewriteTemplate { pub tc_conds: Vec>, } +pub struct GroundEq { + pub name: String, + pub lhs : PatternAst, + pub rhs : PatternAst +} + impl RewriteTemplate { - pub fn to_rewrite(self, cfg: RewriteConfig) -> Res { - // TODO: How do we handle `allow_unsat_conditions`? One option would be to simply not add - // the conditional statements when the option is enabled. I'm not sure what to do - // about tc conditions though, because some of them are a result of tc inst erasure - // and should always be enforced. Perhaps, can we determine which tc conditions are a - // result of tc inst erasure and still check those? + // TODO: How do we handle `allow_unsat_conditions`? One option would be to simply not add + // the conditional statements when the option is enabled. I'm not sure what to do + // about tc conditions though, because some of them are a result of tc inst erasure + // and should always be enforced. Perhaps, can we determine which tc conditions are a + // result of tc inst erasure and still check those? + pub fn to_rewrite(self, cfg: RewriteConfig) -> Res> { + // If the rewrite contains neither conditions nor pattern variables, it's a ground equation. + if self.prop_conds.is_empty() && self.tc_conds.is_empty() && + self.lhs.vars().is_empty() && self.rhs.vars().is_empty() { + return Ok(Either::Right(GroundEq { name: self.name, lhs: self.lhs.ast, rhs: self.rhs.ast })) + } let lhs = if self.prop_conds.is_empty() { self.lhs.clone() @@ -61,7 +73,7 @@ impl RewriteTemplate { let applier = LeanApplier { lhs: self.lhs, rhs: self.rhs, tc_conds: self.tc_conds, cfg }; match Rewrite::new(self.name, lhs, applier) { - Ok(rw) => Ok(rw), + Ok(rw) => Ok(Either::Left(rw)), Err(err) => Err(Error::Rewrite(err.to_string())) } } diff --git a/Rust/Egg/src/util.rs b/Rust/Egg/src/util.rs index f7aebd4..b4b08cc 100644 --- a/Rust/Egg/src/util.rs +++ b/Rust/Egg/src/util.rs @@ -2,6 +2,11 @@ use std::collections::HashSet; use std::hash::Hash; use egg::*; +pub enum Either { + Left(L), + Right(R), +} + pub fn sub_expr(ast: &RecExpr, i: Id) -> RecExpr { let v: Vec<_> = ast.as_ref()[0..=usize::from(i)].iter().cloned().collect(); RecExpr::from(v) diff --git a/Rust/Slotted/src/basic.rs b/Rust/Slotted/src/basic.rs index 339c96c..f6dee42 100644 --- a/Rust/Slotted/src/basic.rs +++ b/Rust/Slotted/src/basic.rs @@ -1,4 +1,3 @@ -use std::collections::HashMap; use slotted_egraphs::*; use crate::result::*; use crate::analysis::*; diff --git a/Rust/Slotted/src/lib.rs b/Rust/Slotted/src/lib.rs index 2175ccc..58f1a12 100644 --- a/Rust/Slotted/src/lib.rs +++ b/Rust/Slotted/src/lib.rs @@ -88,7 +88,7 @@ impl CRewritesArray { let conds: Vec<_> = conds_strs.iter().map(|cond| Pattern::parse(cond).expect("Failed to parse condition")).collect(); if rw.dirs == RewriteDirections::Forward || rw.dirs == RewriteDirections::Both { - res.push(RewriteTemplate { name: name_str.to_string(), lhs: lhs.clone(), rhs: rhs.clone(), conds: conds.clone() }) + res.push(RewriteTemplate { name: name_str.to_string(), lhs: lhs.clone(), rhs: rhs.clone(), _conds: conds.clone() }) } if rw.dirs == RewriteDirections::Backward || rw.dirs == RewriteDirections::Both { // It is important that we use the "-rev" suffix for reverse rules here, as this is also @@ -96,7 +96,7 @@ impl CRewritesArray { // If we choose another naming scheme, egg may complain about duplicate rules when // `rw.dir == RewriteDirection::Both`. This is the case, for example, for the rewrite // `?a + ?b = ?b + ?a`. - res.push(RewriteTemplate { name: format!("{name_str}-rev"), lhs: rhs, rhs: lhs, conds }) + res.push(RewriteTemplate { name: format!("{name_str}-rev"), lhs: rhs, rhs: lhs, _conds: conds }) } } Ok(res) diff --git a/Rust/Slotted/src/result.rs b/Rust/Slotted/src/result.rs index 8ef4582..05d8c7b 100644 --- a/Rust/Slotted/src/result.rs +++ b/Rust/Slotted/src/result.rs @@ -3,7 +3,6 @@ pub enum Error { Init(String), Goal(String), Guide(String), - Condition(String), Rewrite(String), } @@ -14,7 +13,6 @@ impl ToString for Error { Error::Init(s) => format!("⚡️ {s}"), Error::Goal(s) => format!("⚡️ {s}"), Error::Guide(s) => format!("⚡️ {s}"), - Error::Condition(s) => format!("⚡️ {s}"), Error::Rewrite(s) => format!("⚡️ {s}"), } } diff --git a/Rust/Slotted/src/rewrite.rs b/Rust/Slotted/src/rewrite.rs index b12153e..c72dd4c 100644 --- a/Rust/Slotted/src/rewrite.rs +++ b/Rust/Slotted/src/rewrite.rs @@ -1,4 +1,3 @@ -use std::collections::HashMap; use std::collections::HashSet; use slotted_egraphs::*; use crate::result::*; @@ -9,7 +8,7 @@ pub struct RewriteTemplate { pub name: String, pub lhs: Pattern, pub rhs: Pattern, - pub conds: Vec> + pub _conds: Vec> } fn slots_for_node(e: &LeanExpr) -> HashSet { @@ -44,7 +43,7 @@ fn subst_is_valid(subst: &Subst, illegal_slots: &HashSet) -> bool { } pub fn templates_to_rewrites( - templates: Vec, allow_unsat_conditions: bool + templates: Vec, _allow_unsat_conditions: bool ) -> Res> { let mut result: Vec = vec![]; for template in templates { @@ -66,15 +65,9 @@ pub fn templates_to_rewrites( // Disallows rewriting on primitive e-nodes. if analysis.is_primitive { continue } - let mut rule = template.name.clone(); - - for cond in template.conds.clone() { - let id = pattern_subst(graph, &cond, &subst); - - // TODO: Handle conditions. - } - - graph.union_instantiations(&template.lhs, &template.rhs, &subst, Some(rule)); + // TODO: Handle conditions. + + graph.union_instantiations(&template.lhs, &template.rhs, &subst, Some(template.name.clone())); } }), };