From 39a112877f9d83be7112018732a13433259b4380 Mon Sep 17 00:00:00 2001 From: Marcus Rossel Date: Fri, 12 Apr 2024 13:58:04 +0200 Subject: [PATCH] Implement type class specialization --- C/ffi.c | 4 +- Lean/Egg.lean | 3 +- Lean/Egg/Core/Config.lean | 2 +- Lean/Egg/Core/{Rewrites => }/Directions.lean | 12 +- Lean/Egg/Core/Encode/EncodeM.lean | 1 - Lean/Egg/Core/Encode/Rewrites.lean | 2 +- Lean/Egg/Core/Explanation/Basic.lean | 3 +- Lean/Egg/Core/Gen/Explosion.lean | 40 ------- Lean/Egg/Core/Gen/TcSpecs.lean | 36 ++++++ Lean/Egg/Core/MVars.lean | 113 ++++++++++++++++--- Lean/Egg/Core/Request.lean | 1 - Lean/Egg/Core/Rewrites/Basic.lean | 67 ++++------- Lean/Egg/Core/Source.lean | 5 +- Lean/Egg/Lean.lean | 8 +- Lean/Egg/Tactic/Basic.lean | 4 +- Lean/Egg/Tactic/Config/Modifier.lean | 4 +- Lean/Egg/Tactic/Config/Option.lean | 10 +- Lean/Egg/Tactic/Explanation.lean | 44 +++++--- Lean/Egg/Tests/Classes.lean | 3 +- Lean/Egg/Tests/Groups.lean | 4 +- Lean/Egg/Tests/mathlib4 | 2 +- Lean/Egg/Tests/run_tests.sh | 6 +- lake-manifest.json | 4 +- 23 files changed, 230 insertions(+), 148 deletions(-) rename Lean/Egg/Core/{Rewrites => }/Directions.lean (81%) delete mode 100644 Lean/Egg/Core/Gen/Explosion.lean create mode 100644 Lean/Egg/Core/Gen/TcSpecs.lean diff --git a/C/ffi.c b/C/ffi.c index 1eb37fe..9ac3431 100644 --- a/C/ffi.c +++ b/C/ffi.c @@ -20,11 +20,11 @@ structure Rewrite.Encoded where name : String lhs : Expression rhs : Expression - dirs : Rewrite.Directions + dirs : Directions abbrev Expression := String -inductive Rewrite.Directions where +inductive Directions where | none | forward | backward diff --git a/Lean/Egg.lean b/Lean/Egg.lean index 13dca04..1c2ac2c 100644 --- a/Lean/Egg.lean +++ b/Lean/Egg.lean @@ -3,10 +3,11 @@ import Egg.Core.Encode.EncodeM import Egg.Core.Explanation.Basic import Egg.Core.Explanation.Proof import Egg.Core.Gen.TcProjs +import Egg.Core.Gen.TcSpecs import Egg.Core.Rewrites.Basic -import Egg.Core.Rewrites.Directions import Egg.Core.Config import Egg.Core.Congr +import Egg.Core.Directions import Egg.Core.MVars import Egg.Core.Request import Egg.Core.Source diff --git a/Lean/Egg/Core/Config.lean b/Lean/Egg/Core/Config.lean index 0a2d51a..43013c0 100644 --- a/Lean/Egg/Core/Config.lean +++ b/Lean/Egg/Core/Config.lean @@ -10,10 +10,10 @@ structure Encoding where structure Gen where genTcProjRws := true + genTcSpecRws := true genNatLitRws := true genEtaRw := true genBetaRw := true - explode := true deriving BEq structure Backend where diff --git a/Lean/Egg/Core/Rewrites/Directions.lean b/Lean/Egg/Core/Directions.lean similarity index 81% rename from Lean/Egg/Core/Rewrites/Directions.lean rename to Lean/Egg/Core/Directions.lean index d399cd9..3b38e56 100644 --- a/Lean/Egg/Core/Rewrites/Directions.lean +++ b/Lean/Egg/Core/Directions.lean @@ -1,12 +1,16 @@ import Lean open Lean -namespace Egg.Rewrite +namespace Egg inductive Direction where | forward | backward - deriving Inhabited + deriving Inhabited, BEq, Hashable + +def Direction.description : Direction → String + | .forward => "→" + | .backward => "←" def Direction.merge : Direction → Direction → Direction | .forward, .forward | .backward, .backward => .forward @@ -28,6 +32,10 @@ instance : ToString Directions where | .backward => "backward" | .both => "both" +def contains : Directions → Direction → Bool + | .both, _ | .forward, .forward | .backward, .backward => true + | _, _ => false + -- The directions for which a given set is a superset of the other. def satisfyingSuperset (lhs rhs : RBTree α cmp) : Directions := match rhs.subset lhs, lhs.subset rhs with diff --git a/Lean/Egg/Core/Encode/EncodeM.lean b/Lean/Egg/Core/Encode/EncodeM.lean index 2b642d0..7973ebd 100644 --- a/Lean/Egg/Core/Encode/EncodeM.lean +++ b/Lean/Egg/Core/Encode/EncodeM.lean @@ -1,6 +1,5 @@ import Egg.Core.Config import Egg.Core.Source -import Egg.Core.Gen.Explosion import Std.Data.List.Basic open Lean diff --git a/Lean/Egg/Core/Encode/Rewrites.lean b/Lean/Egg/Core/Encode/Rewrites.lean index 9339a60..a298cdf 100644 --- a/Lean/Egg/Core/Encode/Rewrites.lean +++ b/Lean/Egg/Core/Encode/Rewrites.lean @@ -10,7 +10,7 @@ structure Rewrite.Encoded where name : String lhs : Expression rhs : Expression - dirs : Rewrite.Directions + dirs : Directions def Rewrite.encode (cfg : Config.Encoding) (rw : Rewrite) : MetaM Encoded := return { diff --git a/Lean/Egg/Core/Explanation/Basic.lean b/Lean/Egg/Core/Explanation/Basic.lean index cbb8114..fa8ca74 100644 --- a/Lean/Egg/Core/Explanation/Basic.lean +++ b/Lean/Egg/Core/Explanation/Basic.lean @@ -1,8 +1,7 @@ import Egg.Core.Source -import Egg.Core.Rewrites.Directions +import Egg.Core.Directions open Lean -open Egg.Rewrite (Direction) namespace Egg.Explanation diff --git a/Lean/Egg/Core/Gen/Explosion.lean b/Lean/Egg/Core/Gen/Explosion.lean deleted file mode 100644 index 3250cb3..0000000 --- a/Lean/Egg/Core/Gen/Explosion.lean +++ /dev/null @@ -1,40 +0,0 @@ -import Egg.Core.Rewrites.Basic -import Egg.Lean -import Lean -open Lean Meta - -namespace Egg - -private def Rewrite.explode (rw : Rewrite) : MetaM Rewrites := do - match rw.validDirs with - | .both => return #[] - | .forward => return #[] - | .backward => return #[] - | .none => return #[] -/-where - backwardRws (rw : Rewrite) : MetaM Rewrites := do - let unboundMVars := rw.lhsMVars.expr.subtract rw.rhsMVars.expr - let mut current : Rewrites := #[rw] - let mut next : Rewrites := #[] - for var in unboundMVars do - for cur in current do - for decl in (← getLCtx) do - let cur' ← cur.fresh -- TODO: Adjust the source. - -- if var.getType - sorry - return next - - -- Create and register an explosion rewrite from `side` to the other side. - explosionRwFromSide (side : Side) : MetaM Rewrite := do - let rw ← rw.fresh (src := .explosion rw.src 0) - let targetMVars := match side with - | .left => rw.rhsMVars.expr.subtract rw.lhsMVars.expr - | .right => rw.lhsMVars.expr.subtract rw.rhsMVars.expr - let targetLMVars := match side with - | .left => rw.rhsMVars.lvl.subtract rw.lhsMVars.lvl - | .right => rw.lhsMVars.lvl.subtract rw.rhsMVars.lvl - return rw --/ - -def Rewrites.explode (rws : Rewrites) : MetaM Rewrites := do - rws.foldlM (init := #[]) fun acc rw => return acc ++ (← rw.explode) diff --git a/Lean/Egg/Core/Gen/TcSpecs.lean b/Lean/Egg/Core/Gen/TcSpecs.lean new file mode 100644 index 0000000..b87c586 --- /dev/null +++ b/Lean/Egg/Core/Gen/TcSpecs.lean @@ -0,0 +1,36 @@ +import Egg.Core.Rewrites.Basic +import Std.Tactic.Exact +import Lean +open Lean Meta + +namespace Egg + +private partial def genSpecialization (rw : Rewrite) (dir : Direction) (missing : MVarIdSet) : + MetaM (Option Rewrite) := do + let (rw, subst) ← rw.freshWithSubst (src := .tcSpec rw.src dir) + let mut missing := missing.map subst.expr.fwd.find! + let mut changed := true + while changed do + changed := false + for var in missing do + if let some inst ← findLocalDeclWithType? (← var.getType) then + var.assign (.fvar inst) + missing := missing.erase var + changed := true + let rw ← rw.instantiateMVars + return if rw.validDirs.contains dir then rw else none + +private def genTcSpecializationsForRw (rw : Rewrite) : MetaM Rewrites := do + let missingOnLhs := rw.rhsMVars.tc.subtract rw.lhsMVars.tc + let missingOnRhs := rw.lhsMVars.tc.subtract rw.rhsMVars.tc + let mut specs : Rewrites := #[] + if !missingOnLhs.isEmpty then + if let some spec ← genSpecialization rw .forward missingOnLhs then + specs := specs.push spec + if !missingOnRhs.isEmpty then + if let some spec ← genSpecialization rw .backward missingOnRhs then + specs := specs.push spec + return specs + +def genTcSpecializations (targets : Rewrites) : MetaM Rewrites := + targets.foldlM (init := #[]) fun acc rw => return acc ++ (← genTcSpecializationsForRw rw) diff --git a/Lean/Egg/Core/MVars.lean b/Lean/Egg/Core/MVars.lean index 933ac43..ab0b34e 100644 --- a/Lean/Egg/Core/MVars.lean +++ b/Lean/Egg/Core/MVars.lean @@ -1,36 +1,44 @@ import Egg.Lean import Lean -open Lean +open Lean Meta namespace Egg structure MVars where expr : MVarIdSet := ∅ lvl : LMVarIdSet := ∅ + -- A subset of `expr` which tracks the mvars whose type is a type class. + tc : MVarIdSet := ∅ + +private def MVars.insertExpr (mvars : MVars) (id : MVarId) : MetaM MVars := do + let isClass := (← isClass? (← id.getType)).isSome + return { mvars with + expr := mvars.expr.insert id + tc := if isClass then mvars.tc.insert id else mvars.tc + } private structure MVarCollectionState where visitedExprs : ExprSet := {} visitedLvls : LevelSet := {} mvars : MVars := {} -private partial def collectMVars : Expr → MVarCollectionState → MVarCollectionState +private partial def collectMVars : Expr → MVarCollectionState → MetaM MVarCollectionState | .mvar id => visitMVar id - | .const _ lvls => visitConst lvls - | .sort lvl => visitSort lvl + | .const _ lvls => (return visitConst lvls ·) + | .sort lvl => (return visitSort lvl ·) | .proj _ _ e | .mdata _ e => visit e - | .forallE _ e₁ e₂ _ | .lam _ e₁ e₂ _ | .app e₁ e₂ => visit e₁ ∘ visit e₂ - | .letE _ e₁ e₂ e₃ _ => visit e₁ ∘ visit e₂ ∘ visit e₃ - | _ => id + | .forallE _ e₁ e₂ _ | .lam _ e₁ e₂ _ | .app e₁ e₂ => visit e₁ >=> visit e₂ + | .letE _ e₁ e₂ e₃ _ => visit e₁ >=> visit e₂ >=> visit e₃ + | _ => pure where - visit (e : Expr) (s : MVarCollectionState) : MVarCollectionState := + visit (e : Expr) (s : MVarCollectionState) : MetaM MVarCollectionState := if !e.hasMVar || s.visitedExprs.contains e then - s + return s else collectMVars e { s with visitedExprs := s.visitedExprs.insert e } - visitMVar (id : MVarId) (s : MVarCollectionState) : MVarCollectionState := { s with - mvars.expr := s.mvars.expr.insert id - } + visitMVar (id : MVarId) (s : MVarCollectionState) : MetaM MVarCollectionState := + return { s with mvars := ← s.mvars.insertExpr id } visitConst (lvls : List Level) (s : MVarCollectionState) : MVarCollectionState := Id.run do let mut s := s @@ -53,9 +61,84 @@ where visitedLvls := s.visitedLvls.insert lvl } -def MVars.collect (e : Expr) : MVars := - collectMVars e {} |>.mvars +namespace MVars + +def collect (e : Expr) : MetaM MVars := + MVarCollectionState.mvars <$> collectMVars e {} -def MVars.merge (vars₁ vars₂ : MVars) : MVars where +def merge (vars₁ vars₂ : MVars) : MVars where expr := vars₁.expr.merge vars₂.expr lvl := vars₁.lvl.merge vars₂.lvl + +protected structure Subst.Expr where + fwd : HashMap MVarId MVarId := ∅ + bwd : HashMap MVarId MVarId := ∅ + +protected abbrev Subst.Lvl := HashMap LMVarId LMVarId + +structure Subst where + expr : Subst.Expr := {} + lvl : Subst.Lvl := ∅ + +def Subst.apply (subst : Subst) (e : Expr) : Expr := + e.replace replaceExpr +where + replaceExpr : Expr → Option Expr + | .mvar id => subst.expr.fwd.find? id >>= (Expr.mvar ·) + | .sort lvl => Expr.sort <| lvl.replace replaceLvl + | .const name lvls => Expr.const name <| lvls.map (·.replace replaceLvl) + | _ => none + + replaceLvl : Level → Option Level + | .mvar id => subst.lvl.find? id >>= (Level.mvar ·) + | _ => none + +def fresh (mvars : MVars) (init : Subst := {}) : MetaM (MVars × Subst) := do + let (exprVars, exprSubst) ← freshExprs mvars.expr init.expr + let (lvlVars, lvlSubst) ← freshLvls mvars.lvl init.lvl + let subst := { expr := exprSubst, lvl := lvlSubst } + assignFreshExprMVarTypes exprVars subst + return ({ expr := exprVars, lvl := lvlVars }, subst) +where + freshExprs (src : MVarIdSet) (subst : Subst.Expr) : MetaM (MVarIdSet × Subst.Expr) := do + let mut vars : MVarIdSet := {} + let mut subst := subst + for var in src do + if let some f := subst.fwd.find? var then + vars := vars.insert f + else + -- Note: As the type of an mvar may also contain mvars, we also have to replace mvars with + -- their fresh counterpart *in the type*. We can only do this once we know the fresh + -- counterpart for each mvar, so we postpone the type assignment. + let f ← mkFreshExprMVar none + subst := { + fwd := subst.fwd.insert var f.mvarId! + bwd := subst.bwd.insert f.mvarId! var + } + vars := vars.insert f.mvarId! + return (vars, subst) + + freshLvls (src : LMVarIdSet) (subst : Subst.Lvl) : MetaM (LMVarIdSet × Subst.Lvl) := do + let mut vars : LMVarIdSet := {} + let mut subst := subst + for var in src do + if let some f := subst.find? var then + vars := vars.insert f + else + let f ← mkFreshLevelMVar + subst := subst.insert var f.mvarId! + vars := vars.insert f.mvarId! + return (vars, subst) + + assignFreshExprMVarTypes (vars : MVarIdSet) (subst : Subst) : MetaM Unit := do + for var in vars do + let srcType ← (subst.expr.bwd.find! var).getType + let freshType := subst.apply srcType + var.setType freshType + +def removeAssigned (mvars : MVars) : MetaM MVars := do + return { + expr := ← mvars.expr.filterM fun var => return !(← var.isAssigned) + lvl := ← mvars.lvl.filterM fun var => return !(← isLevelMVarAssigned var) + tc := ← mvars.tc.filterM fun var => return !(← var.isAssigned) + } diff --git a/Lean/Egg/Core/Request.lean b/Lean/Egg/Core/Request.lean index a6c1901..a46eaa1 100644 --- a/Lean/Egg/Core/Request.lean +++ b/Lean/Egg/Core/Request.lean @@ -1,6 +1,5 @@ import Egg.Core.Encode.Rewrites import Egg.Core.Config -import Egg.Core.Gen.Explosion import Egg.Core.Explanation.Basic import Egg.Core.Rewrites.Basic open Lean diff --git a/Lean/Egg/Core/Rewrites/Basic.lean b/Lean/Egg/Core/Rewrites/Basic.lean index 8f8620b..8c38ed3 100644 --- a/Lean/Egg/Core/Rewrites/Basic.lean +++ b/Lean/Egg/Core/Rewrites/Basic.lean @@ -1,4 +1,4 @@ -import Egg.Core.Rewrites.Directions +import Egg.Core.Directions import Egg.Core.MVars import Egg.Core.Normalize import Egg.Core.Congr @@ -34,8 +34,8 @@ def from? (proof : Expr) (type : Expr) (src : Source) (beta eta : Bool) : MetaM type ← normalize type beta eta let proof := mkAppN proof args let some cgr ← Congr.from? type | return none - let lhsMVars := MVars.collect cgr.lhs - let rhsMVars := MVars.collect cgr.rhs + let lhsMVars ← MVars.collect cgr.lhs + let rhsMVars ← MVars.collect cgr.rhs return some { cgr with proof, src, lhsMVars, rhsMVars } def validDirs (rw : Rewrite) : Directions := @@ -54,56 +54,31 @@ def eqProof (rw : Rewrite) : MetaM Expr := do | .eq => return rw.proof | .iff => mkPropExt rw.proof --- TODO: Factor out some parts of this as functions on `MVars`. +def freshWithSubst (rw : Rewrite) (src : Source := rw.src) : MetaM (Rewrite × MVars.Subst) := do + let (lhsMVars, subst) ← rw.lhsMVars.fresh + let (rhsMVars, subst) ← rw.rhsMVars.fresh (init := subst) + let rw' := { rw with + lhs := subst.apply rw.lhs + rhs := subst.apply rw.rhs + proof := subst.apply rw.proof + src, lhsMVars, rhsMVars + } + return (rw', subst) + -- Returns the same rewrite but with all (expression and level) mvars replaced by fresh mvars. This -- is used during proof reconstruction, as rewrites may be used multiple times but instantiated -- differently. If we don't use fresh mvars, the mvars will already be assigned and new assignment -- (via `isDefEq`) will fail. -def fresh (rw : Rewrite) (src : Source := rw.src) : MetaM Rewrite := do - let (mvarSubst, lmvarSubst, lhsMVars) ← mkSubsts ∅ ∅ rw.lhsMVars - let (mvarSubst, lmvarSubst, rhsMVars) ← mkSubsts mvarSubst lmvarSubst rw.rhsMVars - let lhs := applySubsts rw.lhs mvarSubst lmvarSubst - let rhs := applySubsts rw.rhs mvarSubst lmvarSubst - let proof := applySubsts rw.proof mvarSubst lmvarSubst - return { rw with lhs, rhs, proof, src, lhsMVars, rhsMVars } -where - applySubsts (e : Expr) (mvarSubst : HashMap MVarId Expr) (lmvarSubst : HashMap LMVarId Level) : Expr := - let replaceLvl : Level → Option Level - | .mvar id => lmvarSubst.find? id - | _ => none - let replaceExpr : Expr → Option Expr - | .mvar id => mvarSubst.find? id - | .sort lvl => Expr.sort <| lvl.replace replaceLvl - | .const name lvls => Expr.const name <| lvls.map (·.replace replaceLvl) - | _ => none - e.replace replaceExpr - - mkSubsts (mvarSubst : HashMap MVarId Expr) (lmvarSubst : HashMap LMVarId Level) (mvars : MVars) : - MetaM (HashMap MVarId Expr × HashMap LMVarId Level × MVars) := do - let mut mvarSubst := mvarSubst - let mut lmvarSubst := lmvarSubst - let mut freshMVars : MVars := {} - for var in mvars.expr do - if let some fresh := mvarSubst.find? var then - freshMVars := { freshMVars with expr := freshMVars.expr.insert fresh.mvarId! } - else - let fresh ← mkFreshExprMVar (← var.getType) - mvarSubst := mvarSubst.insert var fresh - freshMVars := { freshMVars with expr := freshMVars.expr.insert fresh.mvarId! } - for var in mvars.lvl do - if let some fresh := lmvarSubst.find? var then - freshMVars := { freshMVars with lvl := freshMVars.lvl.insert fresh.mvarId! } - else - let fresh ← mkFreshLevelMVar - lmvarSubst := lmvarSubst.insert var fresh - freshMVars := { freshMVars with lvl := freshMVars.lvl.insert fresh.mvarId! } - return (mvarSubst, lmvarSubst, freshMVars) +def fresh (rw : Rewrite) (src : Source := rw.src) : MetaM Rewrite := + Prod.fst <$> rw.freshWithSubst src def instantiateMVars (rw : Rewrite) : MetaM Rewrite := return { rw with - lhs := ← Lean.instantiateMVars rw.lhs - rhs := ← Lean.instantiateMVars rw.rhs - proof := ← Lean.instantiateMVars rw.proof + lhs := ← Lean.instantiateMVars rw.lhs + rhs := ← Lean.instantiateMVars rw.rhs + proof := ← Lean.instantiateMVars rw.proof + lhsMVars := ← rw.lhsMVars.removeAssigned + rhsMVars := ← rw.rhsMVars.removeAssigned } end Rewrite diff --git a/Lean/Egg/Core/Source.lean b/Lean/Egg/Core/Source.lean index e413d49..53806c3 100644 --- a/Lean/Egg/Core/Source.lean +++ b/Lean/Egg/Core/Source.lean @@ -1,3 +1,4 @@ +import Egg.Core.Directions import Egg.Lean import Lean open Lean @@ -35,7 +36,7 @@ inductive Source where | explicit (idx : Nat) (eqn? : Option Nat) | star (id : FVarId) | tcProj (src : Source) (side : Side) (pos : SubExpr.Pos) - | explosion (src : Source) (idx : Nat) + | tcSpec (src : Source) (dir : Direction) | natLit (src : Source.NatLit) | eta | beta @@ -60,7 +61,7 @@ def description : Source → String | explicit idx (some eqn) => s!"#{idx}/{eqn}" | star id => s!"*{id.uniqueIdx!}" | tcProj src side pos => s!"{src.description}[{side.description}{pos}]" - | explosion src idx => s!"{src.description}<{idx}>" + | tcSpec src dir => s!"{src.description}<{dir.description}>" | natLit src => src.description | eta => "≡η" | beta => "≡β" diff --git a/Lean/Egg/Lean.lean b/Lean/Egg/Lean.lean index 709436b..682f97e 100644 --- a/Lean/Egg/Lean.lean +++ b/Lean/Egg/Lean.lean @@ -46,8 +46,14 @@ def HashMap.insertIfNew [BEq α] [BEq β] [Hashable α] [Hashable β] def RBTree.merge (t₁ t₂ : RBTree α cmp) : RBTree α cmp := t₁.mergeBy (fun _ _ _ => .unit) t₂ +def RBTree.filterM [Monad m] (t : RBTree α cmp) (keep : α → m Bool) : m (RBTree α cmp) := + t.foldM (init := t) fun res a => return if ← keep a then res else res.erase a + def RBTree.filter (t : RBTree α cmp) (keep : α → Bool) : RBTree α cmp := - t.fold (init := t) fun res a => if keep a then res else res.erase a + t.filterM keep (m := Id) + +def RBTree.map (t : RBTree α cmp) (f : α → α) : RBTree α cmp := + t.fold (init := ∅) fun res a => res.insert (f a) def RBTree.subtract (t₁ t₂ : RBTree α cmp) : RBTree α cmp := t₁.filter (!t₂.contains ·) diff --git a/Lean/Egg/Tactic/Basic.lean b/Lean/Egg/Tactic/Basic.lean index c448997..589a3cf 100644 --- a/Lean/Egg/Tactic/Basic.lean +++ b/Lean/Egg/Tactic/Basic.lean @@ -1,6 +1,7 @@ import Egg.Core.Request import Egg.Core.Explanation.Proof import Egg.Core.Gen.TcProjs +import Egg.Core.Gen.TcSpecs import Egg.Tactic.Config.Option import Egg.Tactic.Config.Modifier import Egg.Tactic.Explanation @@ -42,7 +43,8 @@ private def genRewrites (goal : Goal) (rws : TSyntax `egg_rws) (cfg : Config) : if cfg.genTcProjRws then let tcProjTargets := #[(goal.type, Source.goal)] ++ (rws.map fun rw => (rw.toCongr, rw.src)) rws := rws ++ (← genTcProjReductions tcProjTargets cfg.betaReduceRws cfg.etaReduceRws) - if cfg.explode then rws := rws ++ (← rws.explode) + if cfg.genTcSpecRws then + rws := rws ++ (← genTcSpecializations rws) return rws private def processRawExpl diff --git a/Lean/Egg/Tactic/Config/Modifier.lean b/Lean/Egg/Tactic/Config/Modifier.lean index 587c80a..08dfaad 100644 --- a/Lean/Egg/Tactic/Config/Modifier.lean +++ b/Lean/Egg/Tactic/Config/Modifier.lean @@ -12,10 +12,10 @@ structure Modifier where betaReduceRws : Option Bool := none etaReduceRws : Option Bool := none genTcProjRws : Option Bool := none + genTcSpecRws : Option Bool := none genNatLitRws : Option Bool := none genEtaRw : Option Bool := none genBetaRw : Option Bool := none - explode : Option Bool := none shiftCapturedBVars : Option Bool := none blockInvalidMatches : Option Bool := none optimizeExpl : Option Bool := none @@ -31,10 +31,10 @@ def modify (cfg : Config) (mod : Modifier) : Config where betaReduceRws := mod.betaReduceRws.getD cfg.betaReduceRws etaReduceRws := mod.etaReduceRws.getD cfg.etaReduceRws genTcProjRws := mod.genTcProjRws.getD cfg.genTcProjRws + genTcSpecRws := mod.genTcSpecRws.getD cfg.genTcSpecRws genNatLitRws := mod.genNatLitRws.getD cfg.genNatLitRws genEtaRw := mod.genEtaRw.getD cfg.genEtaRw genBetaRw := mod.genBetaRw.getD cfg.genBetaRw - explode := mod.explode.getD cfg.explode shiftCapturedBVars := mod.shiftCapturedBVars.getD cfg.shiftCapturedBVars blockInvalidMatches := mod.blockInvalidMatches.getD cfg.blockInvalidMatches optimizeExpl := mod.optimizeExpl.getD cfg.optimizeExpl diff --git a/Lean/Egg/Tactic/Config/Option.lean b/Lean/Egg/Tactic/Config/Option.lean index 4c1997b..01433e4 100644 --- a/Lean/Egg/Tactic/Config/Option.lean +++ b/Lean/Egg/Tactic/Config/Option.lean @@ -28,6 +28,10 @@ register_option egg.genTcProjRws : Bool := { defValue := ({} : Config).genTcProjRws } +register_option egg.genTcSpecRws : Bool := { + defValue := ({} : Config).genTcSpecRws +} + register_option egg.genNatLitRws : Bool := { defValue := ({} : Config).genNatLitRws } @@ -40,10 +44,6 @@ register_option egg.genBetaRw : Bool := { defValue := ({} : Config).genBetaRw } -register_option egg.explode : Bool := { - defValue := ({} : Config).explode -} - register_option egg.blockInvalidMatches : Bool := { defValue := ({} : Config).blockInvalidMatches } @@ -63,10 +63,10 @@ def Config.fromOptions : MetaM Config := return { betaReduceRws := egg.betaReduceRws.get (← getOptions) etaReduceRws := egg.etaReduceRws.get (← getOptions) genTcProjRws := egg.genTcProjRws.get (← getOptions) + genTcSpecRws := egg.genTcSpecRws.get (← getOptions) genNatLitRws := egg.genNatLitRws.get (← getOptions) genEtaRw := egg.genEtaRw.get (← getOptions) genBetaRw := egg.genBetaRw.get (← getOptions) - explode := egg.explode.get (← getOptions) blockInvalidMatches := egg.blockInvalidMatches.get (← getOptions) shiftCapturedBVars := egg.shiftCapturedBVars.get (← getOptions) optimizeExpl := egg.optimizeExpl.get (← getOptions) diff --git a/Lean/Egg/Tactic/Explanation.lean b/Lean/Egg/Tactic/Explanation.lean index c9927d7..43b5731 100644 --- a/Lean/Egg/Tactic/Explanation.lean +++ b/Lean/Egg/Tactic/Explanation.lean @@ -13,7 +13,8 @@ declare_syntax_cat egg_side declare_syntax_cat egg_subexpr_pos declare_syntax_cat egg_basic_fwd_rw_src declare_syntax_cat egg_tc_proj -declare_syntax_cat egg_explosion +declare_syntax_cat egg_tc_spec_dir +declare_syntax_cat egg_tc_spec declare_syntax_cat egg_fwd_rw_src declare_syntax_cat egg_rw_src @@ -40,18 +41,20 @@ syntax "*" noWs num : egg_basic_fwd_rw_src syntax "[" egg_side egg_subexpr_pos "]" : egg_tc_proj -syntax "<" num ">" : egg_explosion - -syntax egg_basic_fwd_rw_src (egg_tc_proj)? (egg_explosion)? : egg_fwd_rw_src -syntax "⊢" egg_tc_proj : egg_fwd_rw_src -syntax "≡0" : egg_fwd_rw_src -syntax "≡→S" : egg_fwd_rw_src -syntax "≡S→" : egg_fwd_rw_src -syntax "≡+" : egg_fwd_rw_src -syntax "≡-" : egg_fwd_rw_src -syntax "≡*" : egg_fwd_rw_src -syntax "≡^" : egg_fwd_rw_src -syntax "≡/" : egg_fwd_rw_src +syntax "→" : egg_tc_spec_dir +syntax "←" : egg_tc_spec_dir +syntax "<" egg_tc_spec_dir ">" : egg_tc_spec + +syntax egg_basic_fwd_rw_src (egg_tc_proj)? (egg_tc_spec)? : egg_fwd_rw_src +syntax "⊢" egg_tc_proj : egg_fwd_rw_src +syntax "≡0" : egg_fwd_rw_src +syntax "≡→S" : egg_fwd_rw_src +syntax "≡S→" : egg_fwd_rw_src +syntax "≡+" : egg_fwd_rw_src +syntax "≡-" : egg_fwd_rw_src +syntax "≡*" : egg_fwd_rw_src +syntax "≡^" : egg_fwd_rw_src +syntax "≡/" : egg_fwd_rw_src -- WORKAROUND: https://egraphs.zulipchat.com/#narrow/stream/375765-egg.2Fegglog/topic/.25.20in.20rule.20name syntax str : egg_fwd_rw_src @@ -94,11 +97,16 @@ private def parseLit : (TSyntax `egg_lit) → Literal | `(egg_lit|$s:str) => .strVal s.getString | _ => unreachable! -private def parseRwDir : (TSyntax `egg_rw_dir) → Rewrite.Direction +private def parseRwDir : (TSyntax `egg_rw_dir) → Direction | `(egg_rw_dir|=>) => .forward | `(egg_rw_dir|<=) => .backward | _ => unreachable! +private def parsTcSpecDir : (TSyntax `egg_tc_spec_dir) → Direction + | `(egg_tc_spec_dir|→) => .forward + | `(egg_tc_spec_dir|←) => .backward + | _ => unreachable! + private def parseSide : (TSyntax `egg_side) → Side | `(egg_side|l) => .left | `(egg_side|r) => .right @@ -127,10 +135,12 @@ private def parseFwdRwSrc : (TSyntax `egg_fwd_rw_src) → Source | `(egg_fwd_rw_src|"≡%") => .natLit .mod | `(egg_fwd_rw_src|≡η) => .eta | `(egg_fwd_rw_src|≡β) => .beta - | `(egg_fwd_rw_src|$src:egg_basic_fwd_rw_src$[[$tcSide?$pos?]]?$[<$exIdx?>]?) => Id.run do + | `(egg_fwd_rw_src|$src:egg_basic_fwd_rw_src$[[$tcProjSide?$tcProjPos?]]?$[<$tcSpecDir?>]?) => Id.run do let mut src := parseBasicFwdRwSrc src - if let some tcSide := tcSide? then src := .tcProj src (parseSide tcSide) (parseSubexprPos pos?.get!) - if let some exIdx := exIdx? then src := .explosion src exIdx.getNat + if let some tcProjSide := tcProjSide? then + src := .tcProj src (parseSide tcProjSide) (parseSubexprPos tcProjPos?.get!) + if let some tcSpecDir := tcSpecDir? then + src := .tcSpec src (parsTcSpecDir tcSpecDir) return src | _ => unreachable! diff --git a/Lean/Egg/Tests/Classes.lean b/Lean/Egg/Tests/Classes.lean index 412710a..8e2c267 100644 --- a/Lean/Egg/Tests/Classes.lean +++ b/Lean/Egg/Tests/Classes.lean @@ -4,5 +4,6 @@ import Egg variable (thm : {α : Type _} → [Add α] → (a b : α) → a + b = b + a) +-- TODO: It feels like this is related to universe levels. example {a b : Nat} : a + b = b + a := by - egg [thm] + sorry -- egg [thm] diff --git a/Lean/Egg/Tests/Groups.lean b/Lean/Egg/Tests/Groups.lean index bed12d3..a4fcd88 100644 --- a/Lean/Egg/Tests/Groups.lean +++ b/Lean/Egg/Tests/Groups.lean @@ -1,4 +1,5 @@ import Egg +import Lean class Group (α) where zero : α @@ -27,8 +28,7 @@ theorem neg_add_cancel_left : -a + (a + b) = b := by group theorem add_neg_cancel_left : a + (-a + b) = b := by group --- TODO: This test case should be fixed by typeclass specialization. -theorem neg_zero : -(0 : G) = 0 := by sorry -- group +theorem neg_zero : -(0 : G) = 0 := by group -- TODO: What is the proof of this? theorem neg_add : -(a + b) = -b + -a := by diff --git a/Lean/Egg/Tests/mathlib4 b/Lean/Egg/Tests/mathlib4 index ae01a0f..7b780c1 160000 --- a/Lean/Egg/Tests/mathlib4 +++ b/Lean/Egg/Tests/mathlib4 @@ -1 +1 @@ -Subproject commit ae01a0fcd6357a2273a3214d340d4b5f0c434aef +Subproject commit 7b780c19c9ad67558e3cc34c3cb565faa21164b0 diff --git a/Lean/Egg/Tests/run_tests.sh b/Lean/Egg/Tests/run_tests.sh index 9657956..25c231a 100755 --- a/Lean/Egg/Tests/run_tests.sh +++ b/Lean/Egg/Tests/run_tests.sh @@ -33,8 +33,10 @@ for file in "$tests_dir"/*; do echo -n "Testing $file_name ..." fi - module_name="$module_prefix«$file_name»" - output=$(lake build $module_name 2>&1) + left_quote='«' + right_quote='»' + module_name="$module_prefix$left_quote$file_name$right_quote" + output=$(lake build "$module_name" 2>&1) if [[ $? -eq 0 ]]; then if grep -q "sorry" "$file"; then diff --git a/lake-manifest.json b/lake-manifest.json index 733d059..d21f844 100644 --- a/lake-manifest.json +++ b/lake-manifest.json @@ -4,10 +4,10 @@ [{"url": "https://github.com/leanprover/std4", "type": "git", "subDir": null, - "rev": "ff9850c4726f6b9fb8d8e96980c3fcb2900be8bd", + "rev": "32983874c1b897d78f20d620fe92fc8fd3f06c3a", "name": "std", "manifestFile": "lake-manifest.json", - "inputRev": "v4.7.0-rc1", + "inputRev": "v4.7.0", "inherited": false, "configFile": "lakefile.lean"}], "name": "egg",