From 3f95cae3ca167686a784b6f387c4686ab2d085ac Mon Sep 17 00:00:00 2001 From: Marcus Rossel Date: Fri, 26 Apr 2024 18:05:07 +0200 Subject: [PATCH] Fix handling of ambient level mvars --- Lean/Egg/Core/Encode/Basic.lean | 26 +++++++++---------- Lean/Egg/Core/Encode/EncodeM.lean | 11 ++++---- Lean/Egg/Core/Encode/Facts.lean | 2 +- Lean/Egg/Core/Encode/Guides.lean | 2 +- Lean/Egg/Core/Encode/Rewrites.lean | 6 ++--- Lean/Egg/Core/MVars/Ambient.lean | 22 +++++++++++----- Lean/Egg/Core/Request.lean | 4 +-- Lean/Egg/Tactic/Basic.lean | 26 ++++++++++++++----- Lean/Egg/Tactic/Trace.lean | 9 +++++-- ...bient MVar.lean => Ambient Expr MVar.lean} | 0 Lean/Egg/Tests/Ambient Level MVar.lean | 10 +++++++ Lean/Egg/Tests/Classes.lean | 8 ------ 12 files changed, 76 insertions(+), 50 deletions(-) rename Lean/Egg/Tests/{Ambient MVar.lean => Ambient Expr MVar.lean} (100%) create mode 100644 Lean/Egg/Tests/Ambient Level MVar.lean delete mode 100644 Lean/Egg/Tests/Classes.lean diff --git a/Lean/Egg/Core/Encode/Basic.lean b/Lean/Egg/Core/Encode/Basic.lean index 12ef8ef..2b4c40c 100644 --- a/Lean/Egg/Core/Encode/Basic.lean +++ b/Lean/Egg/Core/Encode/Basic.lean @@ -15,20 +15,20 @@ private def Expression.erased : Expression := open EncodeM -private def encodeLevel (src : Source) : Level → EncodeM Expression +private def encodeLevel : Level → EncodeM Expression | .zero => return "0" - | .succ l => return s!"(succ {← encodeLevel src l})" - | .max l₁ l₂ => return s!"(max {← encodeLevel src l₁} {← encodeLevel src l₂})" - | .imax l₁ l₂ => return s!"(imax {← encodeLevel src l₁} {← encodeLevel src l₂})" - -- TODO: This should check whether the level mvar is ambient, just as for expression mvars. - -- Once we do this, we can also remove `exprSrc` from the monad state. - | .mvar id => return if src.isRewrite then s!"?{id.uniqueIdx!}" else s!"(uvar {id.uniqueIdx!})" + | .succ l => return s!"(succ {← encodeLevel l})" + | .max l₁ l₂ => return s!"(max {← encodeLevel l₁} {← encodeLevel l₂})" + | .imax l₁ l₂ => return s!"(imax {← encodeLevel l₁} {← encodeLevel l₂})" | .param name => return s!"(param {name})" + | .mvar id => do + if (← isAmbientLvl id) + then return s!"(uvar {id.uniqueIdx!})" + else return s!"?{id.uniqueIdx!}" -- Note: This function expects its input expression to be normalized (cf. `Egg.normalize`). -partial def encode (e : Expr) (src : Source) (cfg : Config.Encoding) (amb : MVars.Ambient) : - MetaM Expression := - Prod.fst <$> (go e).run { exprSrc := src, config := cfg, amb } +partial def encode (e : Expr) (cfg : Config.Encoding) (amb : MVars.Ambient) : MetaM Expression := + Prod.fst <$> (go e).run { config := cfg, amb } where go (e : Expr) : EncodeM Expression := do if ← needsProofErasure e then return Expression.erased else core e @@ -37,7 +37,7 @@ where | .bvar idx => return s!"(bvar {idx})" | .fvar id => encodeFVar id | .mvar id => encodeMVar id - | .sort lvl => return s!"(sort {← encodeLevel (← exprSrc) lvl})" + | .sort lvl => return s!"(sort {← encodeLevel lvl})" | .const name lvls => return s!"(const {name}{← encodeConstLvls lvls})" | .app fn arg => return s!"(app {← go fn} {← go arg})" | .lam _ ty b _ => encodeLam ty b @@ -52,12 +52,12 @@ where else return s!"(fvar {id.uniqueIdx!})" encodeMVar (id : MVarId) : EncodeM Expression := do - if (← isAmbient id) + if (← isAmbientExpr id) then return s!"(mvar {id.uniqueIdx!})" else return s!"?{id.uniqueIdx!}" encodeConstLvls (lvls : List Level) : EncodeM Expression := - lvls.foldlM (init := "") (return s!"{·} {← encodeLevel (← exprSrc) ·}") + lvls.foldlM (init := "") (return s!"{·} {← encodeLevel ·}") encodeLam (ty b : Expr) : EncodeM Expression := do let dom ← if (← config).eraseLambdaDomains then pure Expression.erased else go ty diff --git a/Lean/Egg/Core/Encode/EncodeM.lean b/Lean/Egg/Core/Encode/EncodeM.lean index 7ede787..82ef4c4 100644 --- a/Lean/Egg/Core/Encode/EncodeM.lean +++ b/Lean/Egg/Core/Encode/EncodeM.lean @@ -8,7 +8,6 @@ open Lean namespace Egg structure EncodeM.State where - exprSrc : Source config : Config.Encoding bvars : List FVarId := [] amb : MVars.Ambient @@ -17,14 +16,14 @@ abbrev EncodeM := StateT EncodeM.State MetaM namespace EncodeM -def exprSrc : EncodeM Source := - State.exprSrc <$> get - def config : EncodeM Config.Encoding := State.config <$> get -def isAmbient (mvar : MVarId) : EncodeM Bool := do - return (← get).amb.contains mvar +def isAmbientExpr (mvar : MVarId) : EncodeM Bool := do + return (← get).amb.expr.contains mvar + +def isAmbientLvl (lmvar : LMVarId) : EncodeM Bool := do + return (← get).amb.lvl.contains lmvar -- Note: This only works as intended if `m` does not add any additional bvars (permanently). def withInstantiatedBVar (ty body : Expr) (m : Expr → EncodeM α) : EncodeM α := do diff --git a/Lean/Egg/Core/Encode/Facts.lean b/Lean/Egg/Core/Encode/Facts.lean index dac385a..4ed9de7 100644 --- a/Lean/Egg/Core/Encode/Facts.lean +++ b/Lean/Egg/Core/Encode/Facts.lean @@ -11,4 +11,4 @@ abbrev Facts.Encoded := Array Fact.Encoded def Facts.encode (facts : Facts) (cfg : Config.Encoding) (amb : MVars.Ambient) : MetaM Facts.Encoded := - facts.mapM fun fact => Egg.encode fact.type fact.src cfg amb + facts.mapM fun fact => Egg.encode fact.type cfg amb diff --git a/Lean/Egg/Core/Encode/Guides.lean b/Lean/Egg/Core/Encode/Guides.lean index d6bdf6c..3581025 100644 --- a/Lean/Egg/Core/Encode/Guides.lean +++ b/Lean/Egg/Core/Encode/Guides.lean @@ -11,4 +11,4 @@ abbrev Guides.Encoded := Array Guide.Encoded def Guides.encode (guides : Guides) (cfg : Config.Encoding) (amb : MVars.Ambient) : MetaM Guides.Encoded := - guides.mapM fun guide => Egg.encode guide.expr guide.src cfg amb + guides.mapM fun guide => Egg.encode guide.expr cfg amb diff --git a/Lean/Egg/Core/Encode/Rewrites.lean b/Lean/Egg/Core/Encode/Rewrites.lean index f41fbef..f0815c8 100644 --- a/Lean/Egg/Core/Encode/Rewrites.lean +++ b/Lean/Egg/Core/Encode/Rewrites.lean @@ -16,10 +16,10 @@ structure Rewrite.Encoded where def Rewrite.encode (rw : Rewrite) (cfg : Config.Encoding) (amb : MVars.Ambient) : MetaM Encoded := return { name := rw.src.description - lhs := ← Egg.encode rw.lhs rw.src cfg amb - rhs := ← Egg.encode rw.rhs rw.src cfg amb + lhs := ← Egg.encode rw.lhs cfg amb + rhs := ← Egg.encode rw.rhs cfg amb dirs := rw.validDirs - conds := ← rw.conds.mapM fun cond => do Egg.encode cond.type rw.src cfg amb + conds := ← rw.conds.mapM fun cond => do Egg.encode cond.type cfg amb } namespace Rewrites diff --git a/Lean/Egg/Core/MVars/Ambient.lean b/Lean/Egg/Core/MVars/Ambient.lean index 400fe27..4b4287e 100644 --- a/Lean/Egg/Core/MVars/Ambient.lean +++ b/Lean/Egg/Core/MVars/Ambient.lean @@ -4,20 +4,28 @@ open Lean namespace Egg.MVars -abbrev Ambient := PersistentHashMap MVarId MetavarDecl +abbrev Ambient.Expr := PersistentHashMap MVarId MetavarDecl + +abbrev Ambient.Level := LMVarIdSet + +structure Ambient where + expr : Ambient.Expr + lvl : Ambient.Level namespace Ambient -def get : MetaM Ambient := +def Expr.get : MetaM Ambient.Expr := return (← getMCtx).decls -def unassigned (amb : Ambient) : MetaM MVarIdSet := - amb.foldlM (init := ∅) fun res mvar _ => +def unassigned (amb : Ambient) : MetaM (MVarIdSet × LMVarIdSet) := do + let expr ← amb.expr.foldlM (init := ∅) fun res mvar _ => return if !(← mvar.isAssigned) then res.insert mvar else res + let lvl ← amb.lvl.filterM fun lmvar => return !(← isLevelMVarAssigned lmvar) + return (expr, lvl) end Ambient def remove (mvars : MVars) (amb : Ambient) : MVars where - expr := mvars.expr.filter (!amb.contains ·) - lvl := mvars.lvl - tc := mvars.tc.filter (!amb.contains ·) + expr := mvars.expr.filter (!amb.expr.contains ·) + lvl := mvars.lvl.filter (!amb.lvl.contains ·) + tc := mvars.tc.filter (!amb.expr.contains ·) diff --git a/Lean/Egg/Core/Request.lean b/Lean/Egg/Core/Request.lean index 7c02405..3aa3dfe 100644 --- a/Lean/Egg/Core/Request.lean +++ b/Lean/Egg/Core/Request.lean @@ -45,8 +45,8 @@ def encoding (goal : Congr) (rws : Rewrites) (facts : Facts) (guides : Guides) (cfg : Egg.Config) (amb : MVars.Ambient) : MetaM Request := do return { - lhs := ← encode goal.lhs .goal cfg amb - rhs := ← encode goal.rhs .goal cfg amb + lhs := ← encode goal.lhs cfg amb + rhs := ← encode goal.rhs cfg amb rws := ← rws.encode cfg amb facts := ← facts.encode cfg amb guides := ← guides.encode cfg amb diff --git a/Lean/Egg/Tactic/Basic.lean b/Lean/Egg/Tactic/Basic.lean index cd3e953..2eaee80 100644 --- a/Lean/Egg/Tactic/Basic.lean +++ b/Lean/Egg/Tactic/Basic.lean @@ -42,6 +42,13 @@ where else throwError "expected goal to be of type '=' or '↔', but found:\n{← ppExpr goalType}" +-- TODO: We should also consider the level mvars of all `Fact`s. +private def collectAmbientMVars (goal : Goal) (guides : Guides) : MetaM MVars.Ambient := do + let expr ← MVars.Ambient.Expr.get + let goalLvl := (← MVars.collect (← goal.type.expr)).lvl + let guidesLvl ← guides.foldlM (init := ∅) fun res g => return res.merge (← MVars.collect g.expr).lvl + return { expr, lvl := goalLvl.merge guidesLvl } + private def tracePremises (ps : Premises) (tc : Rewrites) (cfg : Config.Gen) : TacticM Unit := do let cls := `egg.rewrites withTraceNode cls (fun _ => return "Rewrites") do @@ -106,8 +113,13 @@ private def processRawExpl return prf where catchLooseMVars (prf : Expr) (amb : MVars.Ambient) : MetaM Unit := do - for mvar in (prf.collectMVars {}).result do - unless amb.contains mvar do throwError m!"egg: final proof contains mvar {Expr.mvar mvar}" + let mvars ← MVars.collect prf + for mvar in mvars.expr do + unless amb.expr.contains mvar do + throwError m!"egg: final proof contains expression mvar {Expr.mvar mvar}" + for lmvar in mvars.lvl do + unless amb.lvl.contains lmvar do + throwError m!"egg: final proof contains level mvar {Level.mvar lmvar}" private def traceRequest (req : Request) : TacticM Unit := do let cls := `egg.encoded @@ -122,14 +134,14 @@ elab "egg " mod:egg_cfg_mod rws:egg_prems base:(egg_base)? guides:(egg_guides)? let cfg := (← Config.fromOptions).modify mod cfg.trace `egg.config goal.withContext do - let amb ← MVars.Ambient.get + let goal ← parseGoal goal base + let guides := (← guides.mapM Guides.parseGuides).getD #[] + let amb ← collectAmbientMVars goal guides amb.trace `egg.ambient -- We increase the mvar context depth, so that ambient mvars aren't unified during proof -- reconstruction. Note that this also means that we can't assign the `goal` mvar within this -- do-block. let proof? ← withNewMCtxDepth do - let goal ← parseGoal goal base - let guides := (← guides.mapM Guides.parseGuides).getD #[] let (rws, facts) ← genPremises goal rws guides cfg amb let req ← Request.encoding goal.type rws facts guides cfg amb traceRequest req @@ -140,8 +152,8 @@ elab "egg " mod:egg_cfg_mod rws:egg_prems base:(egg_base)? guides:(egg_guides)? if let .beforeProof := cfg.exitPoint then return none return some (← processRawExpl rawExpl goal rws facts amb) if let some proof := proof? - then goal.assignIfDefeq proof - else goal.admit + then goal.id.assignIfDefeq proof + else goal.id.admit -- WORKAROUND: This fixes `Tests/EndOfInput *`. macro "egg" mod:egg_cfg_mod : tactic => `(tactic| egg $mod) diff --git a/Lean/Egg/Tactic/Trace.lean b/Lean/Egg/Tactic/Trace.lean index 10377b9..e4c347b 100644 --- a/Lean/Egg/Tactic/Trace.lean +++ b/Lean/Egg/Tactic/Trace.lean @@ -130,5 +130,10 @@ where nonrec def MVars.Ambient.trace (amb : MVars.Ambient) (cls : Name) : TacticM Unit := do withTraceNode cls (fun _ => return "Ambient MVars") do - for m in ← amb.unassigned do - trace cls fun _ => m!"{Expr.mvar m}" + let (exprs, lvls) ← amb.unassigned + withTraceNode cls (fun _ => return "Expression") (collapsed := false) do + for m in exprs do + trace cls fun _ => m!"{Expr.mvar m}" + withTraceNode cls (fun _ => return "Level") (collapsed := false) do + for m in lvls do + trace cls fun _ => m!"{Level.mvar m}" diff --git a/Lean/Egg/Tests/Ambient MVar.lean b/Lean/Egg/Tests/Ambient Expr MVar.lean similarity index 100% rename from Lean/Egg/Tests/Ambient MVar.lean rename to Lean/Egg/Tests/Ambient Expr MVar.lean diff --git a/Lean/Egg/Tests/Ambient Level MVar.lean b/Lean/Egg/Tests/Ambient Level MVar.lean new file mode 100644 index 0000000..6e970e4 --- /dev/null +++ b/Lean/Egg/Tests/Ambient Level MVar.lean @@ -0,0 +1,10 @@ +import Egg + +variable (thm₁ : {α : Type} → [Add α] → (a b : α) → a + b = b + a) +variable (thm₂ : {α : Type _} → [Add α] → (a b : α) → a + b = b + a) + +example {a b : Nat} : a + b = b + a := by + egg [thm₁] + +example {a b : Nat} : a + b = b + a := by + egg [thm₂] diff --git a/Lean/Egg/Tests/Classes.lean b/Lean/Egg/Tests/Classes.lean deleted file mode 100644 index 14366de..0000000 --- a/Lean/Egg/Tests/Classes.lean +++ /dev/null @@ -1,8 +0,0 @@ -import Egg - --- This tests how type class arguments are encoded in rewrites and matched against expressions. - -variable (thm : {α : Type _} → [Add α] → (a b : α) → a + b = b + a) - -example {a b : Nat} : a + b = b + a := by - egg [thm]