Skip to content

Commit

Permalink
Fix bugs in proof reconstruction, egraph queries, and construction of…
Browse files Browse the repository at this point in the history
… reified eqs
  • Loading branch information
marcusrossel committed Jan 10, 2025
1 parent 471bf09 commit 47dc59a
Show file tree
Hide file tree
Showing 21 changed files with 247 additions and 149 deletions.
4 changes: 2 additions & 2 deletions Lean/Egg/Core/Config.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def Normalization.noReduce : Normalization where

structure Erasure where
eraseProofs := true
eraseTCInstances := false
eraseTCInstances := true
deriving Inhabited, BEq

def Erasure.noErase : Erasure where
Expand Down Expand Up @@ -72,7 +72,7 @@ structure Debug where
deriving BEq

structure _root_.Egg.Config extends Encoding, DefEq, Gen, Backend, Debug where
retryWithShapes := true
retryWithShapes := false
explLengthLimit := 1000

-- TODO: Why aren't these coercions automatic?
Expand Down
5 changes: 5 additions & 0 deletions Lean/Egg/Core/Explanation/Parse/Shared.lean
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ syntax str : fwd_rw_src

syntax fwd_rw_src (noWs "-rev")? : rw_src
syntax &"=" : rw_src
syntax &"∧" : rw_src

syntax "+" num : shift_offset
syntax "-" num : shift_offset
Expand Down Expand Up @@ -216,6 +217,10 @@ def parseRwSrc : (TSyntax `rw_src) → Rewrite.Descriptor
src := .reifiedEq
dir := .forward
}
| `(rw_src|∧) => {
src := .factAnd
dir := .forward
}
| _ => unreachable!

inductive ParseError where
Expand Down
152 changes: 82 additions & 70 deletions Lean/Egg/Core/Explanation/Proof.lean
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@ inductive Step.Rewrite where
| rw (rw : Egg.Rewrite) (isRefl : Bool)
| defeq (src : Source)
| reifiedEq
| factAnd
deriving Inhabited

def Step.Rewrite.isRefl : Rewrite → Bool
| rw _ isRefl => isRefl
| defeq _ => true
-- TODO: This isn't necessarily true.
| reifiedEq => false
| reifiedEq | factAnd => false

structure Step where
lhs : Expr
Expand Down Expand Up @@ -72,10 +73,10 @@ where

proofStep (idx : Nat) (current next : Expr) (rwInfo : Rewrite.Info) :
MetaM (Proof.Step × Proof.Subgoals) := do
if let .reifiedEq := rwInfo.src then
if let .factAnd := rwInfo.src then
let step := {
lhs := current, rhs := next, proof := ← mkReifiedEqStep idx current next,
rw := .reifiedEq, dir := rwInfo.dir
lhs := current, rhs := next, proof := ← mkFactAndStep idx current next,
rw := .factAnd, dir := rwInfo.dir
}
return (step, [])
if rwInfo.src.isDefEq then
Expand All @@ -84,88 +85,96 @@ where
rw := .defeq rwInfo.src, dir := rwInfo.dir
}
return (step, [])
let some rw := rws.find? rwInfo.src | fail s!"unknown rewrite {rwInfo.src.description}" idx
-- TODO: Can there be conditional rfl proofs?
if ← isRflProof rw.proof then
if let some rw := rws.find? rwInfo.src then
-- TODO: Can there be conditional rfl proofs?
if ← isRflProof rw.proof then
let step := {
lhs := current, rhs := next, proof := ← mkReflStep idx current next rwInfo.src,
rw := .rw rw (isRefl := true), dir := rwInfo.dir
}
return (step, [])
let (prf, subgoals) ← mkCongrStep idx current next rwInfo.pos?.get! <| .inl (← rw.forDir rwInfo.dir)
let step := {
lhs := current, rhs := next, proof := ← mkReflStep idx current next rwInfo.src,
rw := .rw rw (isRefl := true), dir := rwInfo.dir
lhs := current, rhs := next, proof := prf, rw := .rw rw (isRefl := false), dir := rwInfo.dir
}
return (step, [])
let (prf, subgoals) ← mkCongrStep idx current next rwInfo.pos?.get! (← rw.forDir rwInfo.dir)
let step := {
lhs := current, rhs := next, proof := prf, rw := .rw rw (isRefl := false), dir := rwInfo.dir
}
return (step, subgoals)
return (step, subgoals)
else if rwInfo.src.isReifiedEq then
let (prf, subgoals) ← mkCongrStep idx current next rwInfo.pos?.get! <| .inr rwInfo.dir
let step := { lhs := current, rhs := next, proof := prf, rw := .reifiedEq, dir := rwInfo.dir }
return (step, subgoals)
else
fail s!"unknown rewrite {rwInfo.src.description}" idx

mkReflStep (idx : Nat) (current next : Expr) (src : Source) : MetaM Expr := do
unless ← isDefEq current next do
fail s!"unification failure for proof by reflexivity with rw {src.description}" idx
mkEqRefl next

mkReifiedEqStep (idx : Nat) (current next : Expr) : MetaM Expr := do
unless next.isTrue do
fail "invalid RHS of reified equality step from\n\n '{current}'\to\n '{next}'" idx
let some (lhs, rhs) := current.eqOrIff?
| fail "invalid LHS of reified equality step from\n\n '{current}'\to\n '{next}'" idx
unless ← isDefEq lhs rhs do
fail "invalid LHS of reified equality step from\n\n '{current}'\to\n '{next}'" idx
mkEqTrue (← mkEqRefl lhs)
/- This loops, and I think reified eq-steps might always be by refl, because we only introduce
a union by reified eq for fresh e-classes.
```
let some prf ← mkSubproof current next
| fail m!"reified equality '{current} = {next}' could not be proven" idx
return prf
```
-/

mkCongrStep (idx : Nat) (current next : Expr) (pos : SubExpr.Pos) (rw : Rewrite) :
mkFactAndStep (idx : Nat) (current next : Expr) : MetaM Expr := do
let .app (.app (.const ``And []) (.const ``True [])) (.const ``True []) := current
| fail m!"invalid LHS of ∧-step from\n\n '{current}'\nto\n '{next}'" idx
unless next.isTrue do fail m!"invalid RHS of ∧-step from\n\n '{current}'\nto\n '{next}'" idx
mkAppM ``true_and #[.const ``True []]

mkCongrStep (idx : Nat) (current next : Expr) (pos : SubExpr.Pos) (rw? : Sum Rewrite Direction) :
MetaM (Expr × Proof.Subgoals) := do
let mvc := (← getMCtx).mvarCounter
let (lhs, rhs, subgoals) ← placeCHoles idx current next pos rw
let (lhs, rhs, subgoals) ← placeCHoles idx current next pos rw?
try return (← (← mkCongrOf 0 mvc lhs rhs).eq, subgoals)
catch err => fail m!"'mkCongrOf' failed with\n {err.toMessageData}" idx

placeCHoles (idx : Nat) (current next : Expr) (pos : SubExpr.Pos) (rw : Rewrite) :
placeCHoles (idx : Nat) (current next : Expr) (pos : SubExpr.Pos) (rw? : Sum Rewrite Direction) :
MetaM (Expr × Expr × Proof.Subgoals) := do
replaceSubexprs (root₁ := current) (root₂ := next) (p := pos) fun lhs rhs => do
-- It's necessary that we create the fresh rewrite (that is, create the fresh mvars) in *this*
-- local context as otherwise the mvars can't unify with variables under binders.
let rw ← rw.fresh
unless ← isDefEq lhs rw.lhs do failIsDefEq "LHS" rw.src lhs rw.lhs rw.mvars.lhs current next idx
/- TODO: Remove?
let lhsType ← inferType lhs
let rwLhsType ← inferType rw.lhs
let _ ← isDefEq lhsType rwLhsType
synthLingeringTcErasureMVars lhs
synthLingeringTcErasureMVars rw.lhs
unless ← isDefEq lhs rw.lhs do
failIsDefEq "LHS" rw.src lhs rw.lhs rw.mvars.lhs.expr current next idx
-/
unless ← isDefEq rhs rw.rhs do failIsDefEq "RHS" rw.src rhs rw.rhs rw.mvars.rhs current next idx
let mut subgoals := []
let conds := rw.conds.filter (!·.isProven)
for cond in conds do
let cond ← cond.instantiateMVars
match cond.kind with
| .proof =>
let some p ← proveCondition cond.type
| fail m!"condition '{cond.type}' of rewrite {rw.src.description} could not be proven" idx
unless ← isDefEq cond.expr p do
fail m!"proof of condition '{cond.type}' of rewrite {rw.src.description} was invalid" idx
| .tcInst =>
let some p ← synthInstance? cond.type
| fail m!"type class condition '{cond.type}' of rewrite {rw.src.description} could not be synthesized" idx
unless ← isDefEq cond.expr p do
fail m!"synthesized type class for condition '{cond.type}' of rewrite {rw.src.description} was invalid" idx
let proof ← rw.eqProof
return (
← mkCHole (forLhs := true) lhs proof,
← mkCHole (forLhs := false) rhs proof,
subgoals
)
match rw? with
| .inr reifiedEqDir =>
let proof ← proveReifiedEq idx lhs rhs reifiedEqDir
return (
← mkCHole (forLhs := true) lhs proof,
← mkCHole (forLhs := false) rhs proof,
[]
)
| .inl rw =>
let rw ← rw.fresh
unless ← isDefEq lhs rw.lhs do failIsDefEq "LHS" rw.src lhs rw.lhs rw.mvars.lhs current next idx
unless ← isDefEq rhs rw.rhs do failIsDefEq "RHS" rw.src rhs rw.rhs rw.mvars.rhs current next idx
let mut subgoals := []
let conds := rw.conds.filter (!·.isProven)
for cond in conds do
let cond ← cond.instantiateMVars
match cond.kind with
| .proof =>
let some p ← proveCondition cond.type
| fail m!"condition '{cond.type}' of rewrite {rw.src.description} could not be proven" idx
unless ← isDefEq cond.expr p do
fail m!"proof of condition '{cond.type}' of rewrite {rw.src.description} was invalid" idx
| .tcInst =>
let some p ← synthInstance? cond.type
| fail m!"type class condition '{cond.type}' of rewrite {rw.src.description} could not be synthesized" idx
unless ← isDefEq cond.expr p do
fail m!"synthesized type class for condition '{cond.type}' of rewrite {rw.src.description} was invalid" idx
let proof ← rw.eqProof
return (
← mkCHole (forLhs := true) lhs proof,
← mkCHole (forLhs := false) rhs proof,
subgoals
)

proveReifiedEq (idx : Nat) (current next : Expr) (dir : Direction) : MetaM Expr := do
let (current, next) := match dir with
| .forward => (current, next)
| .backward => (next, current)
unless next.isTrue do
fail m!"invalid RHS of reified equality step from\n\n '{current}'\nto\n '{next}'" idx
let some (lhs, rhs) := current.eqOrIff?
| fail m!"invalid LHS (not an equivalence) of reified equality step from\n\n '{current}'\nto\n '{next}'" idx
unless ← isDefEq lhs rhs do
fail m!"invalid LHS (not defeq) of reified equality step from\n\n '{current}'\nto\n '{next}'" idx
match dir with
| .forward => mkEqTrue (← mkEqRefl lhs)
| .backward => mkEqSymm <| ← mkEqTrue (← mkEqRefl lhs)

failIsDefEq
{α} (side : String) (src : Source) (expr rwExpr : Expr) (rwMVars : MVars)
Expand All @@ -187,10 +196,13 @@ where
mkSubproof (lhs rhs : Expr) : MetaM (Option Expr) := do
let req ← Request.Equiv.encoding lhs rhs ctx
let rawExpl := egraph.run req
withTraceNode `egg.explanation (fun _ => return "Subexplanation") do trace[egg.explanation] rawExpl.str
if rawExpl.str.isEmpty then return none
let expl ← rawExpl.parse
let proof ← expl.proof rws egraph ctx
proof.prove { lhs, rhs, rel := .eq }
-- `EGraph.run` proves `(lhs = rhs) = True`, so we still need to convert that to a proof of
-- `lhs = rhs`.
mkOfEqTrue <| ← proof.prove { lhs := ← mkEq lhs rhs, rhs := .const ``True [], rel := .eq }

synthLingeringTcErasureMVars (e : Expr) : MetaM Unit := do
let mvars := (← instantiateMVars e).collectMVars {} |>.result
Expand Down
3 changes: 3 additions & 0 deletions Lean/Egg/Core/MVars/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ structure _root_.Egg.MVars where
lvl : HashMap LMVarId Properties := ∅
deriving Inhabited

def isEmpty (mvars : MVars) : Bool :=
mvars.expr.isEmpty && mvars.lvl.isEmpty

def visibleExpr (mvars : MVars) (cfg : Config.Erasure) : MVarIdSet :=
mvars.expr.fold (init := ∅) fun result m ps =>
if ps.isVisible cfg then result.insert m else result
Expand Down
20 changes: 18 additions & 2 deletions Lean/Egg/Core/Premise/Rewrites.lean
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ inductive Kind where
| proof
| tcInst

def Kind.isProof : Kind → Bool
| proof => true
| tcInst => false

def Kind.forType? (ty : Expr) : MetaM (Option Kind) := do
if ← Meta.isProp ty then
return some .proof
Expand Down Expand Up @@ -124,9 +128,21 @@ where
def isConditional (rw : Rewrite) : Bool :=
!rw.conds.isEmpty

def isGroundEq (rw : Rewrite) : Bool :=
rw.conds.isEmpty && rw.mvars.lhs.isEmpty && rw.mvars.rhs.isEmpty

def validDirs (rw : Rewrite) (cfg : Config.Erasure) : Directions :=
let exprDirs := Directions.satisfyingSuperset (rw.mvars.lhs.visibleExpr cfg) (rw.mvars.rhs.visibleExpr cfg)
let lvlDirs := Directions.satisfyingSuperset (rw.mvars.lhs.visibleLevel cfg) (rw.mvars.rhs.visibleLevel cfg)
-- MVars appearing in propositional conditions are definitely going to be part of the rewrite's
-- LHS, so they can (and should be) ignored when computing valid directions.
-- TODO: How does visibility work in conditions?
let propCondExpr : MVarIdSet := rw.conds.filter (·.kind.isProof) |>.foldl (init := ∅) (·.union <| ·.mvars.visibleExpr cfg)
let propCondLevel : LMVarIdSet := rw.conds.filter (·.kind.isProof) |>.foldl (init := ∅) (·.union <| ·.mvars.visibleLevel cfg)
let visibleExprLhs := rw.mvars.lhs.visibleExpr cfg |>.filter (!propCondExpr.contains ·)
let visibleExprRhs := rw.mvars.rhs.visibleExpr cfg |>.filter (!propCondExpr.contains ·)
let visibleLevelLhs := rw.mvars.lhs.visibleLevel cfg |>.filter (!propCondLevel.contains ·)
let visibleLevelRhs := rw.mvars.rhs.visibleLevel cfg |>.filter (!propCondLevel.contains ·)
let exprDirs := Directions.satisfyingSuperset visibleExprLhs visibleExprRhs
let lvlDirs := Directions.satisfyingSuperset visibleLevelLhs visibleLevelRhs
exprDirs.meet lvlDirs

-- Returns the same rewrite but with its type and proof potentially flipped to match the given
Expand Down
6 changes: 6 additions & 0 deletions Lean/Egg/Core/Source.lean
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ inductive Source where
| explicit (idx : Nat) (eqn? : Option Nat)
| star (id : FVarId)
| reifiedEq
| factAnd
| tcProj (src : Source) (loc : Source.TcProjLocation) (pos : SubExpr.Pos) (depth : Nat)
| tcSpec (src : Source) (spec : Source.TcSpec)
| nestedSplit (src : Source) (dir : Direction)
Expand Down Expand Up @@ -125,6 +126,7 @@ def description : Source → String
| explicit idx (some eqn) => s!"#{idx}/{eqn}"
| star id => s!"*{id.uniqueIdx!}"
| reifiedEq => "="
| factAnd => "∧"
| tcProj src loc pos dep => s!"{src.description}[{loc.description}{pos.asNat},{dep}]"
| tcSpec src spec => s!"{src.description}<{spec.description}>"
| nestedSplit src dir => s!"{src.description}{dir.description}⁆"
Expand Down Expand Up @@ -163,3 +165,7 @@ def isSubst : Source → Bool
def involvesBinders : Source → Bool
| subst _ | shift _ | eta _ | beta => true
| _ => false

def isReifiedEq : Source → Bool
| reifiedEq => true
| _ => false
4 changes: 0 additions & 4 deletions Lean/Egg/Tactic/Premises/Gen/GenM.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@ open Lean Meta Elab Tactic

-- TODO: Perform pruning during generation, not after.

-- TODO: It might be ok to silently prune non-autogenerated (that is, user-provided) rewrites with
-- unbound conditions, as certain conditional rewrites may only ever be intended to be used
-- with tc specialization, explosion, or other rewrite generation.

namespace Egg

def Rewrites.contains (tgts : Rewrites) (rw : Rewrite) : MetaM Bool := do
Expand Down
3 changes: 3 additions & 0 deletions Lean/Egg/Tactic/Trace.lean
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ nonrec def Proof.trace (prf : Proof) (cls : Name) : TacticM Unit := do
| .reifiedEq =>
withTraceNode cls (fun _ => return step.rhs) do
trace cls fun _ => m!"Reified Equality"
| .factAnd =>
withTraceNode cls (fun _ => return step.rhs) do
trace cls fun _ => m!"Fact ∧ Fact"

nonrec def MVars.Ambient.trace (amb : MVars.Ambient) (cls : Name) : TacticM Unit := do
withTraceNode cls (fun _ => return "Ambient MVars") do
Expand Down
1 change: 1 addition & 0 deletions Lean/Egg/Tests/Calc.lean
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ example (h₁ : a = b) (h₂ : b = c) : a = c := by
a = b with [h₁]
_ = c with [h₂]

set_option trace.egg true in
example (h₁ : 0 = 0 → a = b) : a = b := by
egg calc [h₁, (rfl : 0 = 0)]
_ = _
Expand Down
29 changes: 29 additions & 0 deletions Lean/Egg/Tests/Cond Valid MVar Dirs.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import Egg
open scoped Egg

-- This test ensures that condition mvars are correctly taken into account when determining the
-- valid direction of rewrites.

/--
info: [egg.rewrites] Rewrites
[egg.rewrites] Basic (1)
[egg.rewrites] #0(⇔): h
[egg.rewrites] ?b = ?a
[egg.rewrites] Conditions
[egg.rewrites] ?a = ?b
[egg.rewrites] LHS MVars
[?b: [.unconditionallyVisible]]
[egg.rewrites] RHS MVars
[?a: [.unconditionallyVisible]]
[egg.rewrites] Tagged (0)
[egg.rewrites] Builtin (0)
[egg.rewrites] Derived (0)
[egg.rewrites] Definitional
[egg.rewrites] Pruned (0)
-/
#guard_msgs(info) in
set_option trace.egg.rewrites true in
set_option egg.builtins false in
egg_no_defeq in
example (h : ∀ a b : Nat, a = b → b = a) : true = true := by
egg [h]
2 changes: 1 addition & 1 deletion Lean/Egg/Tests/NatLit.lean
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ elab "app" n:num fn:ident arg:term : term => open Lean.Elab.Term in do
let rec go (n : Nat) := if n = 0 then elabTerm arg none else return .app fn <| ← go (n - 1)
go n.getNat

example : (app 100 Nat.succ (nat_lit 0)) = (nat_lit 100) := by egg
example : (app 80 Nat.succ (nat_lit 0)) = (nat_lit 80) := by egg

-- Note: This produces a gigantic proof.
example (f : Nat → Nat) (h : ∀ x, f x = x.succ) : 30 = app 30 f 0 := by
Expand Down
Loading

0 comments on commit 47dc59a

Please sign in to comment.