Skip to content

Commit

Permalink
Add todo about tc proj generation
Browse files Browse the repository at this point in the history
  • Loading branch information
marcusrossel committed Apr 16, 2024
1 parent 8130010 commit 6dfbee2
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 20 deletions.
22 changes: 12 additions & 10 deletions Lean/Egg/Core/Gen/TcProjs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ private structure State where
covered : HashSet TcProj := ∅
deriving Inhabited

private def State.covers (s : State) (proj : TcProj) : Bool :=
s.covered.contains proj || s.projs.contains proj

private partial def tcProjs
(e : Expr) (src : Source) (side? : Option Side) (covered : HashSet TcProj) :
MetaM TcProjIndex :=
Expand All @@ -51,12 +54,9 @@ where
let args := s.args[:info.numParams + 1].toArray
if args.back.isMVar || args.any (·.hasLooseBVars) then return s
let proj : TcProj := { const, lvls, args }
return { s with
projs :=
if s.covered.contains proj
then s.projs
else s.projs.insertIfNew proj (.tcProj src side? s.pos)
}
if s.covers proj
then return s
else return { s with projs := s.projs.insert proj (.tcProj src side? s.pos) }

visitBindingBody (b : Expr) (s : State) : MetaM State := do
let s' ← go b { s with pos := s.pos.pushBindingBody }
Expand Down Expand Up @@ -85,8 +85,13 @@ def Rewrites.tcProjTargets (rws : Rewrites) : Array TcProjTarget := Id.run do
def Guides.tcProjTargets (guides : Guides) : Array TcProjTarget :=
guides.map fun guide => { expr := guide.expr, src := guide.src, side? := none }

-- TODO: This still produces many redundant rewrites which differ only by mvars. Is there an
-- efficient way to check if two `TcProj`s are equal up to mvar renaming?
-- Note that for this check to be valid, you also need to know which mvars are "local" and
-- which are ambient.
--
-- Note: This function expects its inputs' expressions to be normalized (cf. `Egg.normalize`).
def genTcProjReductions'
def genTcProjReductions
(targets : Array TcProjTarget) (covered : HashSet TcProj) (beta eta : Bool) :
MetaM (Rewrites × HashSet TcProj) := do
let mut covered := covered
Expand All @@ -97,6 +102,3 @@ def genTcProjReductions'
covered := covered.insert proj
if let some rw ← proj.reductionRewrite? src beta eta then rws := rws.push rw
return (rws, covered)

def genTcProjReductions (targets : Array TcProjTarget) (beta eta : Bool) : MetaM Rewrites :=
Prod.fst <$> genTcProjReductions' targets ∅ beta eta
4 changes: 0 additions & 4 deletions Lean/Egg/Lean.lean
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@ def LMVarId.fromUniqueIdx (idx : Nat) : LMVarId :=

deriving instance BEq, Hashable for SubExpr.Pos

def HashMap.insertIfNew [BEq α] [BEq β] [Hashable α] [Hashable β]
(m : HashMap α β) (a : α) (b : β) : HashMap α β :=
if m.contains a then m else m.insert a b

def RBTree.merge (t₁ t₂ : RBTree α cmp) : RBTree α cmp :=
t₁.mergeBy (fun _ _ _ => .unit) t₂

Expand Down
8 changes: 4 additions & 4 deletions Lean/Egg/Tactic/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,15 @@ where
if cfg.genTcSpecRws then specTodo := rws
while (cfg.genTcProjRws && !projTodo.isEmpty) || (cfg.genTcSpecRws && !specTodo.isEmpty) do
if cfg.genTcProjRws then
let (projRws, cov) ← genTcProjReductions' projTodo covered cfg.betaReduceRws cfg.etaReduceRws
let (projRws, cov) ← genTcProjReductions projTodo covered cfg.betaReduceRws cfg.etaReduceRws
covered := cov
specTodo := specTodo ++ projRws
tcRws := tcRws ++ projRws
tcRws := tcRws ++ projRws
if cfg.genTcSpecRws then
let specRws ← genTcSpecializations specTodo
specTodo := #[]
projTodo := specRws.tcProjTargets
tcRws := tcRws ++ specRws
tcRws := tcRws ++ specRws
return tcRws

private def processRawExpl
Expand Down Expand Up @@ -117,5 +117,5 @@ elab "egg " mod:egg_cfg_mod rws:egg_rws base:(egg_base)? guides:(egg_guides)? :
let rawExpl := req.run
processRawExpl rawExpl goal rws cfg.toDebug amb

-- WORKAROUND: This fixes `Tests/EndOfInput`.
-- WORKAROUND: This fixes `Tests/EndOfInput *`.
macro "egg" mod:egg_cfg_mod : tactic => `(tactic| egg $mod)
2 changes: 0 additions & 2 deletions Lean/Egg/Tactic/Trace.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ initialize registerTraceClass `egg.reconstruction (inherited := true)

namespace Egg

#check MessageData

nonrec def MVars.format (mvars : MVars) : MetaM Format := do
let expr := format <| ← mvars.expr.toList.mapM (ppExpr <| Expr.mvar ·)
let tc := format <| ← mvars.tc.toList.mapM (ppExpr <| Expr.mvar ·)
Expand Down

0 comments on commit 6dfbee2

Please sign in to comment.