Skip to content

Commit

Permalink
Undo sparse explanations introduced in 7ba9316
Browse files Browse the repository at this point in the history
  • Loading branch information
marcusrossel committed Apr 16, 2024
1 parent dc1add6 commit 2b09600
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 46 deletions.
2 changes: 1 addition & 1 deletion Lean/Egg/Core/Explanation/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ inductive Expression where
deriving Inhabited

structure Step extends Rewrite.Info where
dst : Option Expression
dst : Expression
deriving Inhabited

end Explanation
Expand Down
63 changes: 19 additions & 44 deletions Lean/Egg/Core/Explanation/Proof.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ import Egg.Core.Explanation.Congr
import Egg.Core.Rewrites
open Lean Meta

-- TODO: Simplify tracing by adding `MessageData` instances for relevant types.

namespace Egg.Explanation

private partial def replaceSubexprs
Expand Down Expand Up @@ -100,17 +98,17 @@ def proof (expl : Explanation) (cgr : Congr) (rws : Rewrites) (amb : AmbientMVar
for step in steps, idx in [:steps.size] do
withTraceNode `egg.reconstruction (fun _ => return s!"Step {idx}") do
trace[egg.reconstruction] step.src.description
trace[egg.reconstruction] ← step.dst.mapM (·.toExpr)
trace[egg.reconstruction] ← step.dst.toExpr
unless ← isDefEq cgr.lhs current do
throwError s!"{errorPrefix} initial expression is not defeq to lhs of proof goal"
let mut proof ← mkEqRefl current
for step in steps, idx in [:steps.size] do
let next? ← step.dst.mapM (·.toExpr)
let (stepEq, next)do
let next ← step.dst.toExpr
let stepEq ← do
withTraceNode `egg.reconstruction (fun _ => return m!"Step {idx}") do
trace[egg.reconstruction] m!"Current: {current}"
trace[egg.reconstruction] m!"Next: {next?}"
proofStep current next? step.toInfo
trace[egg.reconstruction] m!"Next: {next}"
proofStep current next step.toInfo
proof ← mkEqTrans proof stepEq
current := next
checkFinalProof proof current steps
Expand All @@ -120,63 +118,40 @@ def proof (expl : Explanation) (cgr : Congr) (rws : Rewrites) (amb : AmbientMVar
where
errorPrefix := "egg failed to reconstruct proof:"

proofStep (current : Expr) (next? : Option Expr) (rwInfo : Rewrite.Info) :
MetaM (Expr × Expr) := do
if rwInfo.src.isDefEq then
if let some next := next?
then return (← mkReflStep current next rwInfo.src, next)
else throwError s!"{errorPrefix} defeq steps in sparse explanations aren't supported yet"
proofStep (current next : Expr) (rwInfo : Rewrite.Info) : MetaM Expr := do
if rwInfo.src.isDefEq then return ← mkReflStep current next rwInfo.src
let some rw := rws.find? rwInfo.src | throwError s!"{errorPrefix} unknown rewrite"
if ← isRflProof rw.proof then
if let some next := next? then
return (← mkReflStep current next rwInfo.src, next)
mkCongrStep current next? rwInfo.pos (← rw.forDir rwInfo.dir)
if ← isRflProof rw.proof then return ← mkReflStep current next rwInfo.src
mkCongrStep current next rwInfo.pos (← rw.forDir rwInfo.dir)

mkReflStep (current next : Expr) (src : Source) : MetaM Expr := do
unless ← isDefEq current next do
throwError s!"{errorPrefix} unification failure for proof by reflexivity with rw {src.description}"
mkEqRefl next

mkCongrStep (current : Expr) (next? : Option Expr) (pos : SubExpr.Pos) (rw : Rewrite) :
MetaM (Expr × Expr) := do
mkCongrStep (current next : Expr) (pos : SubExpr.Pos) (rw : Rewrite) : MetaM Expr := do
let mvc := (← getMCtx).mvarCounter
let (lhs, rhs) ← placeCHoles current next? pos rw
try
let proof ← (← mkCongrOf 0 mvc lhs rhs).eq
let next := next?.getD (← unwrapHole rhs pos)
return (proof, next)
catch err =>
throwError m!"{errorPrefix} 'mkCongrOf' failed with\n {err.toMessageData}"

unwrapHole (expr : Expr) (pos : SubExpr.Pos) : MetaM Expr :=
replaceSubexpr (root := expr) (p := pos) fun h =>
if let some (_, val, _) := cHole? h
then return val
else throwError "{errorPrefix} expected to find congr-hole but didn't"
let (lhs, rhs) ← placeCHoles current next pos rw
try (← mkCongrOf 0 mvc lhs rhs).eq
catch err => throwError m!"{errorPrefix} 'mkCongrOf' failed with\n {err.toMessageData}"

placeCHoles (current : Expr) (next? : Option Expr) (pos : SubExpr.Pos) (rw : Rewrite) :
MetaM (Expr × Expr) := do
replaceSubexprs (root₁ := current) (root₂ := next?.getD current) (p := pos) fun lhs rhs => do
placeCHoles (current next : Expr) (pos : SubExpr.Pos) (rw : Rewrite) : MetaM (Expr × Expr) := do
replaceSubexprs (root₁ := current) (root₂ := next) (p := pos) fun lhs rhs => do
-- It's necessary that we create the fresh rewrite (that is, create the fresh mvars) in *this*
-- local context as otherwise the mvars can't unify with variables under binders.
let rw ← rw.fresh
unless ← isDefEq lhs rw.lhs do throwError "{errorPrefix} unification failure for LHS of rewrite"
let rhs ←
if next?.isSome then
unless ← isDefEq rhs rw.rhs do throwError "{errorPrefix} unification failure for RHS of rewrite"
pure rhs
else
pure rw.rhs
unless ← isDefEq rhs rw.rhs do throwError "{errorPrefix} unification failure for RHS of rewrite"
let proof ← rw.eqProof
return (
← mkCHole (forLhs := true) lhs proof,
← mkCHole (forLhs := false) rhs proof
)

checkFinalProof (proof : Expr) (current : Expr) (steps : Array Step) : MetaM Unit := do
if let some last := steps.back?.map (·.dst) |>.getD (some expl.start) then
unless ← isDefEq current (← last.toExpr) do
throwError s!"{errorPrefix} final expression is not defeq to rhs of proof goal"
let last := steps.back?.map (·.dst) |>.getD expl.start
unless ← isDefEq current (← last.toExpr) do
throwError s!"{errorPrefix} final expression is not defeq to rhs of proof goal"
let proof ← instantiateMVars proof
for mvar in (proof.collectMVars {}).result do
unless amb.contains mvar do throwError s!"{errorPrefix} final proof contains mvar {mvar.name}"
2 changes: 1 addition & 1 deletion Lean/Egg/Tests/Int.lean
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import Egg
import Std
import Std.Data.Int.Lemmas

example (a b c d : Int) : ((a * b) - (2 * c)) * d - (a * b) = (d - 1) * (a * b) - (2 * c * d) := by
egg [Int.sub_mul, Int.sub_sub, Int.add_comm, Int.mul_comm, Int.one_mul]

0 comments on commit 2b09600

Please sign in to comment.