Skip to content

Commit

Permalink
Remove tracking for whether the goal contains a binder
Browse files Browse the repository at this point in the history
  • Loading branch information
marcusrossel committed Feb 1, 2025
1 parent 8b8ddac commit 669328b
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 42 deletions.
13 changes: 2 additions & 11 deletions Lean/Egg/Core/Encode/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,8 @@ private def encodeLevel : Level → EncodeM Expression
else return s!"(uvar {id.uniqueIdx!})"

-- Note: This function expects its input expression to be normalized (cf. `Egg.normalize`).
--
-- Returns the encoded expression with a flag indicating whether it contains a binder.
partial def encode' (e : Expr) (cfg : Config.Encoding) : MetaM (Expression × Bool) := do
let (expr, { usedBinder, .. }) ← (go e).run { config := cfg }
return (expr, usedBinder)
partial def encode (e : Expr) (cfg : Config.Encoding) : MetaM Expression := do
Prod.fst <$> (go e).run { config := cfg }
where
go (e : Expr) : EncodeM Expression :=
withCache e do
Expand Down Expand Up @@ -75,19 +72,13 @@ where
lvls.foldlM (init := "") (return s!"{·} {← encodeLevel ·}")

encodeLambda (ty b : Expr) : EncodeM Expression := do
setUsedBinder
-- It's critical that we encode `ty` outside of the `withInstantiatedBVar` block, as otherwise
-- the bvars in `encTy` are incorrectly shifted by 1.
let encTy ← go ty
withInstantiatedBVar ty b fun var? body => return s!"(λ {var?}{encTy} {← go body})"

encodeForall (ty b : Expr) : EncodeM Expression := do
setUsedBinder
-- It's critical that we encode `ty` outside of the `withInstantiatedBVar` block, as otherwise
-- the bvars in `encTy` are incorrectly shifted by 1.
let encTy ← go ty
withInstantiatedBVar ty b fun var? body => return s!"(∀ {var?}{encTy} {← go body})"

-- Note: This function expects its input expression to be normalized (cf. `Egg.normalize`).
def encode (e : Expr) (cfg : Config.Encoding) : MetaM Expression :=
Prod.fst <$> encode' e cfg
10 changes: 3 additions & 7 deletions Lean/Egg/Core/Encode/EncodeM.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ namespace Egg
abbrev Expression := String

structure EncodeM.State where
config : Config.Encoding
bvars : List FVarId := []
cache : HashMap Expr Expression := ∅
usedBinder : Bool := false
config : Config.Encoding
bvars : List FVarId := []
cache : HashMap Expr Expression := ∅

abbrev EncodeM := StateT EncodeM.State MetaM

Expand Down Expand Up @@ -64,6 +63,3 @@ def withoutShapes (m : EncodeM Expression) : EncodeM Expression := do
let enc ← m
set { s with config.shapes := shapes }
return enc

def setUsedBinder : EncodeM Unit :=
modify ({ · with usedBinder := true })
15 changes: 4 additions & 11 deletions Lean/Egg/Core/Request/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -54,22 +54,15 @@ structure _root_.Egg.Request where
vizPath : String
cfg : Request.Config

-- Returns the encoded request with a flag indicating whether the proof goal contains a binder.
def encoding' (goal : Congr) (rws : Rewrites) (guides : Guides) (cfg : Config) :
MetaM (Request × Bool) := do
let (lhs, lhsBinder) ← encode' goal.lhs cfg
let (rhs, rhsBinder) ← encode' goal.rhs cfg
let req := {
lhs, rhs
def encoding (goal : Congr) (rws : Rewrites) (guides : Guides) (cfg : Config) : MetaM Request :=
return {
lhs := ← encode goal.lhs cfg
rhs := ← encode goal.rhs cfg
rws := ← rws.encode cfg
guides := ← guides.encode cfg
vizPath := cfg.vizPath.getD ""
cfg
}
return (req, lhsBinder || rhsBinder)

def encoding (goal : Congr) (rws : Rewrites) (guides : Guides) (cfg : Config) : MetaM Request :=
Prod.fst <$> encoding' goal rws guides cfg

-- IMPORTANT: The C interface to egg depends on the order of these constructors.
inductive Result.StopReason where
Expand Down
19 changes: 9 additions & 10 deletions Lean/Egg/Tactic/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -64,43 +64,42 @@ where
let guides ← do if cfg.derivedGuides then pure (guides ++ (← genDerivedGuides goal.toCongr rws)) else pure guides
runEqSat goal rws guides cfg
match res with
| some (proof, proofTime, result, goalContainsBinder) =>
| some (proof, proofTime, result) =>
if cfg.reporting then
let totalTime := (← IO.monoMsNow) - startTime
logInfo (s!"egg succeeded " ++ formatReport cfg.flattenReports result.report totalTime proofTime result.expl goalContainsBinder)
logInfo (s!"egg succeeded " ++ formatReport cfg.flattenReports result.report totalTime proofTime result.expl)
let (subgoals, _, proof) ← openAbstractMVarsResult proof
appendGoals <| Array.toList <| subgoals.map (·.mvarId!)
goal.id.assignIfDefeq' proof
if let some tk := calcifyTk? then calcify tk proof goal.intros.unzip.snd
| none => goal.id.admit
runEqSat
(goal : Goal) (rws : Rewrites) (guides : Guides) (cfg : Config) :
TacticM <| Option (AbstractMVarsResult × Nat × Request.Result × Bool) := do
let (req, goalContainsBinder) ← Request.encoding' goal.toCongr rws guides cfg
TacticM <| Option (AbstractMVarsResult × Nat × Request.Result) := do
let req ← Request.encoding goal.toCongr rws guides cfg
withTraceNode `egg.encoded (fun _ => return "Encoded") do req.trace `egg.encoded
if let .beforeEqSat := cfg.exitPoint then return none
let result ← req.run cfg.explLengthLimit (onEqSatFailure cfg goalContainsBinder)
let result ← req.run cfg.explLengthLimit (onEqSatFailure cfg)
result.expl.trace `egg.explanation.steps
if let .beforeProof := cfg.exitPoint then return none
let beforeProof ← IO.monoMsNow
match ← resultToProof result goal rws cfg cfg.retryWithShapes cfg.proofFuel? with
| .proof prf =>
let proofTime := (← IO.monoMsNow) - beforeProof
return some (prf, proofTime, result, goalContainsBinder)
return some (prf, proofTime, result)
| .retryWithShapes => runEqSat goal rws guides { cfg with shapes := true }
onEqSatFailure (cfg : Config) (goalContainsBinder : Bool) (report : Request.Result.Report) :
Request.Failure → MetaM MessageData
onEqSatFailure (cfg : Config) (report : Request.Result.Report) : Request.Failure → MetaM MessageData
| .backend msg? => do
let mut msg := msg?
if msg.isEmpty then
let reasonMsg := if report.reasonMsg.isEmpty then "" else s!": {report.reasonMsg}"
msg := s!"egg failed to prove the goal ({report.stopReason.description}{reasonMsg}) "
unless cfg.reporting do return msg
return msg ++ formatReport cfg.flattenReports report (goalContainsBinder := goalContainsBinder)
return msg ++ formatReport cfg.flattenReports report
| .explLength len => do
let msg := s!"egg found an explanation exceeding the length limit ({len} vs {cfg.explLengthLimit})\nYou can increase this limit using 'set_option egg.explLengthLimit <num>'.\n"
unless cfg.reporting do return msg
return msg ++ formatReport cfg.flattenReports report (goalContainsBinder := goalContainsBinder)
return msg ++ formatReport cfg.flattenReports report

syntax &"egg " egg_cfg_mod egg_premises (egg_guides)? : tactic
elab_rules : tactic
Expand Down
5 changes: 2 additions & 3 deletions Lean/Egg/Tactic/Trace.lean
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ def Config.Debug.ExitPoint.format : Config.Debug.ExitPoint → Format

nonrec def formatReport
(flat : Bool) (rep : Request.Result.Report) (totalDuration? proofDuration? : Option Nat := none)
(expl? : Option Explanation := none) (goalContainsBinder : Bool) : Format :=
(expl? : Option Explanation := none) : Format :=
if flat then
"(" ++ (if let some d := totalDuration? then format d else "-") ++ "," ++
(format <| (1000 * rep.time).toUInt64.toNat) ++ "," ++
(if let some d := proofDuration? then format d else "-") ++ "," ++ (format rep.iterations) ++
"," ++ (format rep.nodeCount) ++ "," ++ (format rep.classCount) ++ "," ++
(if let some e := expl? then format e.steps.size ++ s!",{e.involvesBinderRewrites}" else "-,-")
++ "," ++ s!"{goalContainsBinder}" ++ ")"
++ ")"
else
(if let some d := totalDuration? then "\ntotal time: " ++ format d ++ "ms\n" else "") ++
"eqsat time: " ++ (format <| (1000 * rep.time).toUInt64.toNat) ++ "ms\n" ++
Expand All @@ -46,7 +46,6 @@ nonrec def formatReport
"nodes: " ++ (format rep.nodeCount) ++ "\n" ++
"classes: " ++ (format rep.classCount) ++ "\n" ++
(if let some e := expl? then "expl steps: " ++ format e.steps.size ++ s!"\nbinder rws: {e.involvesBinderRewrites}\n" else "") ++
s!"⊢ binders: {goalContainsBinder}" ++
(if rep.rwStats.isEmpty then "" else s!"\nrw stats:\n{rep.rwStats}")

def MVars.Property.toMessageData : MVars.Property → MessageData
Expand Down

0 comments on commit 669328b

Please sign in to comment.