Skip to content

Commit

Permalink
Generate tc-proj rws and tc-spec rws recursively until fixed point
Browse files Browse the repository at this point in the history
This fixes the first calc step in inv_eq_of_mul_eq_one_left
  • Loading branch information
marcusrossel committed Apr 15, 2024
1 parent 7ba9316 commit ca33766
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 72 deletions.
43 changes: 30 additions & 13 deletions Lean/Egg/Core/Gen/TcProjs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@ namespace Egg
-- We expect the `args` to contain `numParams + 1` elements where the `numParams + 1`th argument is
-- the typeclass instance argument for `const`. Also, not arguments can contain loos bvars and the
-- final argument (the typeclass instance) can not be an mvar.
private structure TcProj where
structure TcProj where
const : Name
lvls : List Level
args : Array Expr
deriving BEq, Hashable

abbrev TcProjIndex := HashMap TcProj Source

private def TcProj.reductionRewrite? (proj : TcProj) (src : Source) (beta eta : Bool) : MetaM (Option Rewrite) := do
private def TcProj.reductionRewrite? (proj : TcProj) (src : Source) (beta eta : Bool) :
MetaM (Option Rewrite) := do
let app := mkAppN (.const proj.const proj.lvls) proj.args
let reduced ← withReducibleAndInstances do reduceAll app
if app == reduced then return none
Expand All @@ -26,14 +27,16 @@ private def TcProj.reductionRewrite? (proj : TcProj) (src : Source) (beta eta :
return rw

private structure State where
projs : HashMap TcProj Source := ∅
args : Array Expr := #[]
pos : SubExpr.Pos := .root
projs : TcProjIndex := ∅
args : Array Expr := #[]
pos : SubExpr.Pos := .root
covered : HashSet TcProj := ∅
deriving Inhabited

private partial def tcProjs (e : Expr) (src : Source) (side? : Option Side) (init : TcProjIndex) :
private partial def tcProjs
(e : Expr) (src : Source) (side? : Option Side) (covered : HashSet TcProj) :
MetaM TcProjIndex :=
State.projs <$> go e { projs := init }
State.projs <$> go e { covered }
where
go : Expr → State → MetaM State
| .const c lvls => visitConst c lvls
Expand All @@ -47,8 +50,13 @@ where
unless info.fromClass && s.args.size > info.numParams do return s
let args := s.args[:info.numParams + 1].toArray
if args.back.isMVar || args.any (·.hasLooseBVars) then return s
let projs := s.projs.insertIfNew { const, lvls, args } (.tcProj src side? s.pos)
return { s with projs }
let proj : TcProj := { const, lvls, args }
return { s with
projs :=
if s.covered.contains proj
then s.projs
else s.projs.insertIfNew proj (.tcProj src side? s.pos)
}

visitBindingBody (b : Expr) (s : State) : MetaM State := do
let s' ← go b { s with pos := s.pos.pushBindingBody }
Expand Down Expand Up @@ -78,8 +86,17 @@ def Guides.tcProjTargets (guides : Guides) : Array TcProjTarget :=
guides.map fun guide => { expr := guide.expr, src := guide.src, side? := none }

-- Note: This function expects its inputs' expressions to be normalized (cf. `Egg.normalize`).
def genTcProjReductions (targets : Array TcProjTarget) (beta eta : Bool) : MetaM Rewrites := do
let mut projs : TcProjIndex := ∅
def genTcProjReductions'
(targets : Array TcProjTarget) (covered : HashSet TcProj) (beta eta : Bool) :
MetaM (Rewrites × HashSet TcProj) := do
let mut covered := covered
let mut rws := #[]
for target in targets do
projs ← tcProjs target.expr target.src target.side? projs
projs.toArray.filterMapM fun (proj, src) => proj.reductionRewrite? src beta eta
let projs ← tcProjs target.expr target.src target.side? covered
for (proj, src) in projs.toArray do
covered := covered.insert proj
if let some rw ← proj.reductionRewrite? src beta eta then rws := rws.push rw
return (rws, covered)

def genTcProjReductions (targets : Array TcProjTarget) (beta eta : Bool) : MetaM Rewrites :=
Prod.fst <$> genTcProjReductions' targets ∅ beta eta
9 changes: 6 additions & 3 deletions Lean/Egg/Core/Gen/TcSpecs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@ private partial def genSpecialization (rw : Rewrite) (dir : Direction) (missing
return if rw.validDirs.contains dir then rw else none
where
instanceForType? (type : Expr) : MetaM (Option Expr) := do
if let some inst ← findLocalDeclWithType? type
then return (Expr.fvar inst)
else optional (synthInstance type)
if let some inst ← findLocalDeclWithType? type then
return (Expr.fvar inst)
else if let some inst ← optional (synthInstance type) then
normalize inst false false -- TODO: How should beta/eta be applied here?
else
return none

private def genTcSpecializationsForRw (rw : Rewrite) : MetaM Rewrites := do
let missingOnLhs := rw.rhsMVars.tc.subtract rw.lhsMVars.tc
Expand Down
36 changes: 26 additions & 10 deletions Lean/Egg/Tactic/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import Egg.Tactic.Base
import Egg.Tactic.Guides
import Egg.Tactic.Rewrites
import Egg.Tactic.Trace
import Std.Tactic.Exact
import Lean

open Lean Meta Elab Tactic
Expand Down Expand Up @@ -44,15 +45,30 @@ where
else
throwError "expected goal to be of type '=' or '↔', but found:\n{← ppExpr goalType}"

private def genRewrites (goal : Goal) (rws : TSyntax `egg_rws) (guides : Guides) (cfg : Config) :
TacticM Rewrites := do
let mut rws ← Rewrites.parse cfg.betaReduceRws cfg.etaReduceRws rws
if cfg.genTcProjRws then
let tcProjTargets := goal.tcProjTargets ++ guides.tcProjTargets ++ rws.tcProjTargets
rws := rws ++ (← genTcProjReductions tcProjTargets cfg.betaReduceRws cfg.etaReduceRws)
if cfg.genTcSpecRws then
rws := rws ++ (← genTcSpecializations rws)
return rws
private partial def genRewrites
(goal : Goal) (rws : TSyntax `egg_rws) (guides : Guides) (cfg : Config) : TacticM Rewrites := do
let rws ← Rewrites.parse cfg.betaReduceRws cfg.etaReduceRws rws
return rws ++ (← genTcRws rws)
where
genTcRws (rws : Rewrites) : TacticM Rewrites := do
let mut projTodo := #[]
let mut specTodo := #[]
let mut tcRws := #[]
let mut covered : HashSet TcProj := ∅
if cfg.genTcProjRws then projTodo := goal.tcProjTargets ++ guides.tcProjTargets ++ rws.tcProjTargets
if cfg.genTcSpecRws then specTodo := rws
while (cfg.genTcProjRws && !projTodo.isEmpty) || (cfg.genTcSpecRws && !specTodo.isEmpty) do
if cfg.genTcProjRws then
let (projRws, cov) ← genTcProjReductions' projTodo covered cfg.betaReduceRws cfg.etaReduceRws
covered := cov
specTodo := specTodo ++ projRws
tcRws := tcRws ++ projRws
if cfg.genTcSpecRws then
let specRws ← genTcSpecializations specTodo
specTodo := #[]
projTodo := specRws.tcProjTargets
tcRws := tcRws ++ specRws
return tcRws

private def processRawExpl
(rawExpl : Explanation.Raw) (goal : Goal) (rws : Rewrites) (cfg : Config.Debug)
Expand All @@ -69,7 +85,7 @@ private def processRawExpl
if let some base := goal.base? then proof ← mkEqMP proof (.fvar base)
withTraceNode `egg.reconstruction (fun _ => return "Final Proof") do
trace[egg.reconstruction] proof
goal.id.assign proof
goal.id.assignIfDefeq proof

open Config.Modifier (egg_cfg_mod)

Expand Down
60 changes: 28 additions & 32 deletions Lean/Egg/Tactic/Explanation.lean
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ declare_syntax_cat egg_basic_fwd_rw_src
declare_syntax_cat egg_tc_proj
declare_syntax_cat egg_tc_spec_dir
declare_syntax_cat egg_tc_spec
declare_syntax_cat egg_tc_extension
declare_syntax_cat egg_fwd_rw_src
declare_syntax_cat egg_rw_src

Expand All @@ -38,31 +39,32 @@ syntax ("/" num)+ : egg_subexpr_pos

syntax "#" noWs num (noWs "/" noWs num)? : egg_basic_fwd_rw_src
syntax "*" noWs num : egg_basic_fwd_rw_src
syntax "⊢" : egg_basic_fwd_rw_src
syntax "↣" noWs num : egg_basic_fwd_rw_src

syntax "[" (egg_side)? egg_subexpr_pos "]" : egg_tc_proj

syntax "→" : egg_tc_spec_dir
syntax "←" : egg_tc_spec_dir
syntax "<" egg_tc_spec_dir ">" : egg_tc_spec

syntax egg_basic_fwd_rw_src (egg_tc_proj)? (egg_tc_spec)? : egg_fwd_rw_src
syntax "⊢" egg_tc_proj (egg_tc_spec)? : egg_fwd_rw_src
syntax "↣" num egg_tc_proj (egg_tc_spec)? : egg_fwd_rw_src
syntax "≡0" : egg_fwd_rw_src
syntax "≡→S" : egg_fwd_rw_src
syntax "≡S→" : egg_fwd_rw_src
syntax "≡+" : egg_fwd_rw_src
syntax "≡-" : egg_fwd_rw_src
syntax "≡*" : egg_fwd_rw_src
syntax "≡^" : egg_fwd_rw_src
syntax "≡/" : egg_fwd_rw_src

syntax egg_tc_proj : egg_tc_extension
syntax egg_tc_spec : egg_tc_extension

syntax egg_basic_fwd_rw_src (noWs egg_tc_extension)* : egg_fwd_rw_src
syntax "≡η" : egg_fwd_rw_src
syntax "≡β" : egg_fwd_rw_src
syntax "≡0" : egg_fwd_rw_src
syntax "≡→S" : egg_fwd_rw_src
syntax "≡S→" : egg_fwd_rw_src
syntax "≡+" : egg_fwd_rw_src
syntax "≡-" : egg_fwd_rw_src
syntax "≡*" : egg_fwd_rw_src
syntax "≡^" : egg_fwd_rw_src
syntax "≡/" : egg_fwd_rw_src
-- WORKAROUND: https://egraphs.zulipchat.com/#narrow/stream/375765-egg.2Fegglog/topic/.25.20in.20rule.20name
syntax str : egg_fwd_rw_src
-- syntax "≡%" : egg_fwd_rw_src

syntax "≡η" : egg_fwd_rw_src
syntax "≡β" : egg_fwd_rw_src
syntax str : egg_fwd_rw_src
-- syntax "≡%" : egg_fwd_rw_src

syntax egg_fwd_rw_src (noWs "-rev")? : egg_rw_src

Expand Down Expand Up @@ -121,8 +123,15 @@ private def parseSubexprPos : (TSyntax `egg_subexpr_pos) → SubExpr.Pos
private def parseBasicFwdRwSrc : (TSyntax `egg_basic_fwd_rw_src) → Source
| `(egg_basic_fwd_rw_src|#$idx$[/$eqn?]?) => .explicit idx.getNat (eqn?.map TSyntax.getNat)
| `(egg_basic_fwd_rw_src|*$idx) => .star (.fromUniqueIdx idx.getNat)
| `(egg_basic_fwd_rw_src|⊢) => .goal
| `(egg_basic_fwd_rw_src|↣$idx) => .guide idx.getNat
| _ => unreachable!

private def parseTcExtension (src : Source) : (TSyntax `egg_tc_extension) → Source
| `(egg_tc_extension|[$[$side?]?$pos]) => .tcProj src (side?.map parseSide) (parseSubexprPos pos)
| `(egg_tc_extension|<$dir>) => .tcSpec src (parsTcSpecDir dir)
| _ => unreachable!

private def parseFwdRwSrc : (TSyntax `egg_fwd_rw_src) → Source
| `(egg_fwd_rw_src|≡0) => .natLit .zero
| `(egg_fwd_rw_src|≡→S) => .natLit .toSucc
Expand All @@ -135,21 +144,8 @@ private def parseFwdRwSrc : (TSyntax `egg_fwd_rw_src) → Source
| `(egg_fwd_rw_src|"≡%") => .natLit .mod
| `(egg_fwd_rw_src|≡η) => .eta
| `(egg_fwd_rw_src|≡β) => .beta
| `(egg_fwd_rw_src|⊢[$tcProjSide$tcProjPos]$[<$tcSpecDir?>]?) => Id.run do
let mut src := Source.tcProj .goal (parseSide tcProjSide) (parseSubexprPos tcProjPos)
if let some tcSpecDir := tcSpecDir? then src := .tcSpec src (parsTcSpecDir tcSpecDir)
return src
| `(egg_fwd_rw_src|↣$idx[$[$tcProjSide]?$tcProjPos]$[<$tcSpecDir?>]?) => Id.run do
let mut src := Source.tcProj (.guide idx.getNat) (tcProjSide.map parseSide) (parseSubexprPos tcProjPos)
if let some tcSpecDir := tcSpecDir? then src := .tcSpec src (parsTcSpecDir tcSpecDir)
return src
| `(egg_fwd_rw_src|$src:egg_basic_fwd_rw_src$[[$tcProjSide?$tcProjPos?]]?$[<$tcSpecDir?>]?) => Id.run do
let mut src := parseBasicFwdRwSrc src
if let some tcProjSide := tcProjSide? then
src := .tcProj src (parseSide tcProjSide) (parseSubexprPos tcProjPos?.get!)
if let some tcSpecDir := tcSpecDir? then
src := .tcSpec src (parsTcSpecDir tcSpecDir)
return src
| `(egg_fwd_rw_src|$src:egg_basic_fwd_rw_src$tcExts:egg_tc_extension*) =>
tcExts.foldl (init := parseBasicFwdRwSrc src) parseTcExtension
| _ => unreachable!

private def parseRwSrc : (TSyntax `egg_rw_src) → Rewrite.Descriptor
Expand Down
2 changes: 1 addition & 1 deletion Lean/Egg/Tests/Groups.lean
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,4 @@ theorem neg_add' : -(a + b) = -b + -a := by

-- BUG: Try adding a hypothesis. This should cause the backend to crash.
theorem inv_inv' : -(-a) = a := by
group using (-a + a)
group using -a + a
13 changes: 0 additions & 13 deletions Lean/Egg/Tests/WIP.lean
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,3 @@ set_option egg.shiftCapturedBVars true in
example : True := by
have : (fun x => (fun a => (fun a => a) a) 0) = (fun x => x) := by sorry -- egg [thm₂]
constructor


-- Unrelated to capture avoidance:
--
-- TODO: If we have a theorem like `(fun a b => a) x y = x`, it's only applicable in the forward
-- direction. But once we β-reduce it, it's applicable in both directions. I think that can
-- cause problems during reconstruction as we cannot reconstruct the assignment of `y`.



-- TODO: Something about tc-specialization isn'T quite working yet.
-- The first calc step in inv_eq_of_mul_eq_one_left should be doable with egg, but no tc-spec
-- is generated.

0 comments on commit ca33766

Please sign in to comment.