Skip to content

Commit

Permalink
Fix handling of ambient level mvars
Browse files Browse the repository at this point in the history
  • Loading branch information
marcusrossel committed Apr 26, 2024
1 parent 3140bfb commit 3f95cae
Show file tree
Hide file tree
Showing 12 changed files with 76 additions and 50 deletions.
26 changes: 13 additions & 13 deletions Lean/Egg/Core/Encode/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,20 @@ private def Expression.erased : Expression :=

open EncodeM

private def encodeLevel (src : Source) : Level → EncodeM Expression
private def encodeLevel : Level → EncodeM Expression
| .zero => return "0"
| .succ l => return s!"(succ {← encodeLevel src l})"
| .max l₁ l₂ => return s!"(max {← encodeLevel src l₁} {← encodeLevel src l₂})"
| .imax l₁ l₂ => return s!"(imax {← encodeLevel src l₁} {← encodeLevel src l₂})"
-- TODO: This should check whether the level mvar is ambient, just as for expression mvars.
-- Once we do this, we can also remove `exprSrc` from the monad state.
| .mvar id => return if src.isRewrite then s!"?{id.uniqueIdx!}" else s!"(uvar {id.uniqueIdx!})"
| .succ l => return s!"(succ {← encodeLevel l})"
| .max l₁ l₂ => return s!"(max {← encodeLevel l₁} {← encodeLevel l₂})"
| .imax l₁ l₂ => return s!"(imax {← encodeLevel l₁} {← encodeLevel l₂})"
| .param name => return s!"(param {name})"
| .mvar id => do
if (← isAmbientLvl id)
then return s!"(uvar {id.uniqueIdx!})"
else return s!"?{id.uniqueIdx!}"

-- Note: This function expects its input expression to be normalized (cf. `Egg.normalize`).
partial def encode (e : Expr) (src : Source) (cfg : Config.Encoding) (amb : MVars.Ambient) :
MetaM Expression :=
Prod.fst <$> (go e).run { exprSrc := src, config := cfg, amb }
partial def encode (e : Expr) (cfg : Config.Encoding) (amb : MVars.Ambient) : MetaM Expression :=
Prod.fst <$> (go e).run { config := cfg, amb }
where
go (e : Expr) : EncodeM Expression := do
if ← needsProofErasure e then return Expression.erased else core e
Expand All @@ -37,7 +37,7 @@ where
| .bvar idx => return s!"(bvar {idx})"
| .fvar id => encodeFVar id
| .mvar id => encodeMVar id
| .sort lvl => return s!"(sort {← encodeLevel (← exprSrc) lvl})"
| .sort lvl => return s!"(sort {← encodeLevel lvl})"
| .const name lvls => return s!"(const {name}{← encodeConstLvls lvls})"
| .app fn arg => return s!"(app {← go fn} {← go arg})"
| .lam _ ty b _ => encodeLam ty b
Expand All @@ -52,12 +52,12 @@ where
else return s!"(fvar {id.uniqueIdx!})"

encodeMVar (id : MVarId) : EncodeM Expression := do
if (← isAmbient id)
if (← isAmbientExpr id)
then return s!"(mvar {id.uniqueIdx!})"
else return s!"?{id.uniqueIdx!}"

encodeConstLvls (lvls : List Level) : EncodeM Expression :=
lvls.foldlM (init := "") (return s!"{·} {← encodeLevel (← exprSrc) ·}")
lvls.foldlM (init := "") (return s!"{·} {← encodeLevel ·}")

encodeLam (ty b : Expr) : EncodeM Expression := do
let dom ← if (← config).eraseLambdaDomains then pure Expression.erased else go ty
Expand Down
11 changes: 5 additions & 6 deletions Lean/Egg/Core/Encode/EncodeM.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ open Lean
namespace Egg

structure EncodeM.State where
exprSrc : Source
config : Config.Encoding
bvars : List FVarId := []
amb : MVars.Ambient
Expand All @@ -17,14 +16,14 @@ abbrev EncodeM := StateT EncodeM.State MetaM

namespace EncodeM

def exprSrc : EncodeM Source :=
State.exprSrc <$> get

def config : EncodeM Config.Encoding :=
State.config <$> get

def isAmbient (mvar : MVarId) : EncodeM Bool := do
return (← get).amb.contains mvar
def isAmbientExpr (mvar : MVarId) : EncodeM Bool := do
return (← get).amb.expr.contains mvar

def isAmbientLvl (lmvar : LMVarId) : EncodeM Bool := do
return (← get).amb.lvl.contains lmvar

-- Note: This only works as intended if `m` does not add any additional bvars (permanently).
def withInstantiatedBVar (ty body : Expr) (m : Expr → EncodeM α) : EncodeM α := do
Expand Down
2 changes: 1 addition & 1 deletion Lean/Egg/Core/Encode/Facts.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ abbrev Facts.Encoded := Array Fact.Encoded

def Facts.encode (facts : Facts) (cfg : Config.Encoding) (amb : MVars.Ambient) :
MetaM Facts.Encoded :=
facts.mapM fun fact => Egg.encode fact.type fact.src cfg amb
facts.mapM fun fact => Egg.encode fact.type cfg amb
2 changes: 1 addition & 1 deletion Lean/Egg/Core/Encode/Guides.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ abbrev Guides.Encoded := Array Guide.Encoded

def Guides.encode (guides : Guides) (cfg : Config.Encoding) (amb : MVars.Ambient) :
MetaM Guides.Encoded :=
guides.mapM fun guide => Egg.encode guide.expr guide.src cfg amb
guides.mapM fun guide => Egg.encode guide.expr cfg amb
6 changes: 3 additions & 3 deletions Lean/Egg/Core/Encode/Rewrites.lean
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ structure Rewrite.Encoded where
def Rewrite.encode (rw : Rewrite) (cfg : Config.Encoding) (amb : MVars.Ambient) : MetaM Encoded :=
return {
name := rw.src.description
lhs := ← Egg.encode rw.lhs rw.src cfg amb
rhs := ← Egg.encode rw.rhs rw.src cfg amb
lhs := ← Egg.encode rw.lhs cfg amb
rhs := ← Egg.encode rw.rhs cfg amb
dirs := rw.validDirs
conds := ← rw.conds.mapM fun cond => do Egg.encode cond.type rw.src cfg amb
conds := ← rw.conds.mapM fun cond => do Egg.encode cond.type cfg amb
}

namespace Rewrites
Expand Down
22 changes: 15 additions & 7 deletions Lean/Egg/Core/MVars/Ambient.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,28 @@ open Lean

namespace Egg.MVars

abbrev Ambient := PersistentHashMap MVarId MetavarDecl
abbrev Ambient.Expr := PersistentHashMap MVarId MetavarDecl

abbrev Ambient.Level := LMVarIdSet

structure Ambient where
expr : Ambient.Expr
lvl : Ambient.Level

namespace Ambient

def get : MetaM Ambient :=
def Expr.get : MetaM Ambient.Expr :=
return (← getMCtx).decls

def unassigned (amb : Ambient) : MetaM MVarIdSet :=
amb.foldlM (init := ∅) fun res mvar _ =>
def unassigned (amb : Ambient) : MetaM (MVarIdSet × LMVarIdSet) := do
let expr ← amb.expr.foldlM (init := ∅) fun res mvar _ =>
return if !(← mvar.isAssigned) then res.insert mvar else res
let lvl ← amb.lvl.filterM fun lmvar => return !(← isLevelMVarAssigned lmvar)
return (expr, lvl)

end Ambient

def remove (mvars : MVars) (amb : Ambient) : MVars where
expr := mvars.expr.filter (!amb.contains ·)
lvl := mvars.lvl
tc := mvars.tc.filter (!amb.contains ·)
expr := mvars.expr.filter (!amb.expr.contains ·)
lvl := mvars.lvl.filter (!amb.lvl.contains ·)
tc := mvars.tc.filter (!amb.expr.contains ·)
4 changes: 2 additions & 2 deletions Lean/Egg/Core/Request.lean
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def encoding
(goal : Congr) (rws : Rewrites) (facts : Facts) (guides : Guides) (cfg : Egg.Config) (amb : MVars.Ambient) :
MetaM Request := do
return {
lhs := ← encode goal.lhs .goal cfg amb
rhs := ← encode goal.rhs .goal cfg amb
lhs := ← encode goal.lhs cfg amb
rhs := ← encode goal.rhs cfg amb
rws := ← rws.encode cfg amb
facts := ← facts.encode cfg amb
guides := ← guides.encode cfg amb
Expand Down
26 changes: 19 additions & 7 deletions Lean/Egg/Tactic/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ where
else
throwError "expected goal to be of type '=' or '↔', but found:\n{← ppExpr goalType}"

-- TODO: We should also consider the level mvars of all `Fact`s.
private def collectAmbientMVars (goal : Goal) (guides : Guides) : MetaM MVars.Ambient := do
let expr ← MVars.Ambient.Expr.get
let goalLvl := (← MVars.collect (← goal.type.expr)).lvl
let guidesLvl ← guides.foldlM (init := ∅) fun res g => return res.merge (← MVars.collect g.expr).lvl
return { expr, lvl := goalLvl.merge guidesLvl }

private def tracePremises (ps : Premises) (tc : Rewrites) (cfg : Config.Gen) : TacticM Unit := do
let cls := `egg.rewrites
withTraceNode cls (fun _ => return "Rewrites") do
Expand Down Expand Up @@ -106,8 +113,13 @@ private def processRawExpl
return prf
where
catchLooseMVars (prf : Expr) (amb : MVars.Ambient) : MetaM Unit := do
for mvar in (prf.collectMVars {}).result do
unless amb.contains mvar do throwError m!"egg: final proof contains mvar {Expr.mvar mvar}"
let mvars ← MVars.collect prf
for mvar in mvars.expr do
unless amb.expr.contains mvar do
throwError m!"egg: final proof contains expression mvar {Expr.mvar mvar}"
for lmvar in mvars.lvl do
unless amb.lvl.contains lmvar do
throwError m!"egg: final proof contains level mvar {Level.mvar lmvar}"

private def traceRequest (req : Request) : TacticM Unit := do
let cls := `egg.encoded
Expand All @@ -122,14 +134,14 @@ elab "egg " mod:egg_cfg_mod rws:egg_prems base:(egg_base)? guides:(egg_guides)?
let cfg := (← Config.fromOptions).modify mod
cfg.trace `egg.config
goal.withContext do
let amb ← MVars.Ambient.get
let goal ← parseGoal goal base
let guides := (← guides.mapM Guides.parseGuides).getD #[]
let amb ← collectAmbientMVars goal guides
amb.trace `egg.ambient
-- We increase the mvar context depth, so that ambient mvars aren't unified during proof
-- reconstruction. Note that this also means that we can't assign the `goal` mvar within this
-- do-block.
let proof? ← withNewMCtxDepth do
let goal ← parseGoal goal base
let guides := (← guides.mapM Guides.parseGuides).getD #[]
let (rws, facts) ← genPremises goal rws guides cfg amb
let req ← Request.encoding goal.type rws facts guides cfg amb
traceRequest req
Expand All @@ -140,8 +152,8 @@ elab "egg " mod:egg_cfg_mod rws:egg_prems base:(egg_base)? guides:(egg_guides)?
if let .beforeProof := cfg.exitPoint then return none
return some (← processRawExpl rawExpl goal rws facts amb)
if let some proof := proof?
then goal.assignIfDefeq proof
else goal.admit
then goal.id.assignIfDefeq proof
else goal.id.admit

-- WORKAROUND: This fixes `Tests/EndOfInput *`.
macro "egg" mod:egg_cfg_mod : tactic => `(tactic| egg $mod)
9 changes: 7 additions & 2 deletions Lean/Egg/Tactic/Trace.lean
Original file line number Diff line number Diff line change
Expand Up @@ -130,5 +130,10 @@ where

nonrec def MVars.Ambient.trace (amb : MVars.Ambient) (cls : Name) : TacticM Unit := do
withTraceNode cls (fun _ => return "Ambient MVars") do
for m in ← amb.unassigned do
trace cls fun _ => m!"{Expr.mvar m}"
let (exprs, lvls) ← amb.unassigned
withTraceNode cls (fun _ => return "Expression") (collapsed := false) do
for m in exprs do
trace cls fun _ => m!"{Expr.mvar m}"
withTraceNode cls (fun _ => return "Level") (collapsed := false) do
for m in lvls do
trace cls fun _ => m!"{Level.mvar m}"
File renamed without changes.
10 changes: 10 additions & 0 deletions Lean/Egg/Tests/Ambient Level MVar.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import Egg

variable (thm₁ : {α : Type} → [Add α] → (a b : α) → a + b = b + a)
variable (thm₂ : {α : Type _} → [Add α] → (a b : α) → a + b = b + a)

example {a b : Nat} : a + b = b + a := by
egg [thm₁]

example {a b : Nat} : a + b = b + a := by
egg [thm₂]
8 changes: 0 additions & 8 deletions Lean/Egg/Tests/Classes.lean

This file was deleted.

0 comments on commit 3f95cae

Please sign in to comment.