Skip to content

Commit

Permalink
Complete support for guides
Browse files Browse the repository at this point in the history
  • Loading branch information
marcusrossel committed Apr 12, 2024
1 parent f6cc768 commit d83ed78
Show file tree
Hide file tree
Showing 10 changed files with 114 additions and 48 deletions.
1 change: 1 addition & 0 deletions Lean/Egg.lean
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@ import Egg.Tactic.Config.Option
import Egg.Tactic.Base
import Egg.Tactic.Basic
import Egg.Tactic.Explanation
import Egg.Tactic.Guides
import Egg.Tactic.Rewrites
import Egg.Tactic.Trace
25 changes: 12 additions & 13 deletions Lean/Egg/Core/Encode/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@ private def Expression.erased : Expression :=

open EncodeM

private def encodeLevel : Level → Source → EncodeM Expression
| .zero, _ => return "0"
| .succ l, k => return s!"(succ {← encodeLevel l k})"
| .max l₁ l₂, k => return s!"(max {← encodeLevel l₁ k} {← encodeLevel l₂ k})"
| .imax l₁ l₂, k => return s!"(imax {← encodeLevel l₁ k} {← encodeLevel l₂ k})"
| .mvar id, .goal => return s!"(uvar {id.uniqueIdx!})"
| .mvar id, _ => return s!"?{id.uniqueIdx!}"
| .param name, _ => return s!"(param {name})"
private def encodeLevel (src : Source) : Level → EncodeM Expression
| .zero => return "0"
| .succ l => return s!"(succ {← encodeLevel src l})"
| .max l₁ l₂ => return s!"(max {← encodeLevel src l₁} {← encodeLevel src l₂})"
| .imax l₁ l₂ => return s!"(imax {← encodeLevel src l₁} {← encodeLevel src l₂})"
| .mvar id => return if src.isRewrite then s!"?{id.uniqueIdx!}" else s!"(uvar {id.uniqueIdx!})"
| .param name => return s!"(param {name})"

-- Note: This function expects its input expression to be normalized (cf. `Egg.normalize`).
partial def encode (e : Expr) (src : Source) (cfg : Config.Encoding) : MetaM Expression :=
Expand All @@ -35,7 +34,7 @@ where
| .bvar idx => return s!"(bvar {idx})"
| .fvar id => encodeFVar id
| .mvar id => encodeMVar id
| .sort lvl => return s!"(sort {← encodeLevel lvl (← exprSrc)})"
| .sort lvl => return s!"(sort {← encodeLevel (← exprSrc) lvl})"
| .const name lvls => return s!"(const {name}{← encodeConstLvls lvls})"
| .app fn arg => return s!"(app {← go fn} {← go arg})"
| .lam _ ty b _ => encodeLam ty b
Expand All @@ -50,12 +49,12 @@ where
else return s!"(fvar {id.uniqueIdx!})"

encodeMVar (id : MVarId) : EncodeM Expression := do
match ← exprSrc with
| .goal => return s!"(mvar {id.uniqueIdx!})"
| _ => return s!"?{id.uniqueIdx!}"
if (← exprSrc).isRewrite
then return s!"?{id.uniqueIdx!}"
else return s!"(mvar {id.uniqueIdx!})"

encodeConstLvls (lvls : List Level) : EncodeM Expression :=
lvls.foldlM (init := "") (return s!"{·} {← encodeLevel · (← exprSrc)}")
lvls.foldlM (init := "") (return s!"{·} {← encodeLevel (← exprSrc) ·}")

encodeLam (ty b : Expr) : EncodeM Expression := do
let dom ← if (← config).eraseLambdaDomains then pure Expression.erased else go ty
Expand Down
27 changes: 21 additions & 6 deletions Lean/Egg/Core/Gen/TcProjs.lean
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import Egg.Core.Rewrites
import Egg.Core.Guides
import Lean
open Lean Meta

Expand Down Expand Up @@ -30,7 +31,7 @@ private structure State where
pos : SubExpr.Pos := .root
deriving Inhabited

private partial def tcProjs (e : Expr) (src : Source) (side : Side) (init : TcProjIndex) :
private partial def tcProjs (e : Expr) (src : Source) (side? : Option Side) (init : TcProjIndex) :
MetaM TcProjIndex :=
State.projs <$> go e { projs := init }
where
Expand All @@ -46,7 +47,7 @@ where
unless info.fromClass && s.args.size > info.numParams do return s
let args := s.args[:info.numParams + 1].toArray
if args.back.isMVar || args.any (·.hasLooseBVars) then return s
let projs := s.projs.insertIfNew { const, lvls, args } (.tcProj src side s.pos)
let projs := s.projs.insertIfNew { const, lvls, args } (.tcProj src side? s.pos)
return { s with projs }

visitBindingBody (b : Expr) (s : State) : MetaM State := do
Expand All @@ -61,10 +62,24 @@ where
let s' ← go arg { s with args := #[], pos := s.pos.pushAppArg }
return { s' with args := s.args, pos := s.pos }

structure TcProjTarget where
expr : Expr
src : Source
side? : Option Side

def Rewrites.tcProjTargets (rws : Rewrites) : Array TcProjTarget := Id.run do
let mut sources : Array TcProjTarget := #[]
for rw in rws do
sources := sources.push { expr := rw.lhs, src := rw.src, side? := some .left }
sources := sources.push { expr := rw.rhs, src := rw.src, side? := some .right }
return sources

def Guides.tcProjTargets (guides : Guides) : Array TcProjTarget :=
guides.map fun guide => { expr := guide.expr, src := guide.src, side? := none }

-- Note: This function expects its inputs' expressions to be normalized (cf. `Egg.normalize`).
def genTcProjReductions (targets : Array (Congr × Source)) (beta eta : Bool) : MetaM Rewrites := do
def genTcProjReductions (targets : Array TcProjTarget) (beta eta : Bool) : MetaM Rewrites := do
let mut projs : TcProjIndex := ∅
for (cgr, src) in targets do
projs ← tcProjs cgr.lhs src .left projs
projs ← tcProjs cgr.rhs src .right projs
for target in targets do
projs ← tcProjs target.expr target.src target.side? projs
projs.toArray.filterMapM fun (proj, src) => proj.reductionRewrite? src beta eta
8 changes: 8 additions & 0 deletions Lean/Egg/Core/Guides.lean
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
import Egg.Core.Source
import Egg.Core.Normalize
import Lean
open Lean

namespace Egg

structure Guide where
private mk ::
expr : Expr
src : Source

def Guide.from (expr : Expr) (src : Source) : MetaM Guide :=
return {
expr := ← normalize expr false false
src
}

abbrev Guides := Array Guide
8 changes: 6 additions & 2 deletions Lean/Egg/Core/Source.lean
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ inductive Source where
| guide (idx : Nat)
| explicit (idx : Nat) (eqn? : Option Nat)
| star (id : FVarId)
| tcProj (src : Source) (side : Side) (pos : SubExpr.Pos)
| tcProj (src : Source) (side? : Option Side) (pos : SubExpr.Pos)
| tcSpec (src : Source) (dir : Direction)
| natLit (src : Source.NatLit)
| eta
Expand All @@ -61,7 +61,7 @@ def description : Source → String
| explicit idx none => s!"#{idx}"
| explicit idx (some eqn) => s!"#{idx}/{eqn}"
| star id => s!"*{id.uniqueIdx!}"
| tcProj src side pos => s!"{src.description}[{side.description}{pos}]"
| tcProj src side pos => s!"{src.description}[{side.map (·.description) |>.getD ""}{pos}]"
| tcSpec src dir => s!"{src.description}<{dir.description}>"
| natLit src => src.description
| eta => "≡η"
Expand All @@ -70,6 +70,10 @@ def description : Source → String
instance : ToString Source where
toString := description

def isRewrite : Source → Bool
| goal | guide _ => false
| _ => true

def isDefEq : Source → Bool
| natLit _ | eta | beta => true
| _ => false
22 changes: 15 additions & 7 deletions Lean/Egg/Tactic/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import Egg.Tactic.Config.Option
import Egg.Tactic.Config.Modifier
import Egg.Tactic.Explanation
import Egg.Tactic.Base
import Egg.Tactic.Guides
import Egg.Tactic.Rewrites
import Egg.Tactic.Trace
import Lean
Expand All @@ -21,6 +22,11 @@ private structure Goal where
type : Congr
base? : Option FVarId

private def Goal.tcProjTargets (goal : Goal) : Array TcProjTarget := #[
{ expr := goal.type.lhs, src := .goal, side? := some .left },
{ expr := goal.type.rhs, src := .goal, side? := some .right }
]

private def getAmbientMVars : MetaM Explanation.AmbientMVars :=
return (← getMCtx).decls

Expand All @@ -38,10 +44,11 @@ where
else
throwError "expected goal to be of type '=' or '↔', but found:\n{← ppExpr goalType}"

private def genRewrites (goal : Goal) (rws : TSyntax `egg_rws) (cfg : Config) : TacticM Rewrites := do
private def genRewrites (goal : Goal) (rws : TSyntax `egg_rws) (guides : Guides) (cfg : Config) :
TacticM Rewrites := do
let mut rws ← Rewrites.parse cfg.betaReduceRws cfg.etaReduceRws rws
if cfg.genTcProjRws then
let tcProjTargets := #[(goal.type, Source.goal)] ++ (rws.map fun rw => (rw.toCongr, rw.src))
let tcProjTargets := goal.tcProjTargets ++ guides.tcProjTargets ++ rws.tcProjTargets
rws := rws ++ (← genTcProjReductions tcProjTargets cfg.betaReduceRws cfg.etaReduceRws)
if cfg.genTcSpecRws then
rws := rws ++ (← genTcSpecializations rws)
Expand All @@ -66,15 +73,16 @@ private def processRawExpl

open Config.Modifier (egg_cfg_mod)

elab "egg " mod:egg_cfg_mod rws:egg_rws base:(egg_base)? : tactic => do
elab "egg " mod:egg_cfg_mod rws:egg_rws base:(egg_base)? guides:(egg_guides)? : tactic => do
let goal ← getMainGoal
let mod ← Config.Modifier.parse mod
let cfg := (← Config.fromOptions).modify mod
goal.withContext do
let amb ← getAmbientMVars
let goal ← parseGoal goal base
let rws ← genRewrites goal rws cfg
let req ← Request.encoding goal.type rws #[] cfg
let amb ← getAmbientMVars
let goal ← parseGoal goal base
let guides := (← guides.mapM Guides.parseGuides).getD #[]
let rws ← genRewrites goal rws guides cfg
let req ← Request.encoding goal.type rws guides cfg
req.trace
if cfg.exitPoint == .beforeEqSat then goal.id.admit; return
let rawExpl := req.run
Expand Down
36 changes: 22 additions & 14 deletions Lean/Egg/Tactic/Explanation.lean
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,15 @@ syntax ("/" num)+ : egg_subexpr_pos
syntax "#" noWs num (noWs "/" noWs num)? : egg_basic_fwd_rw_src
syntax "*" noWs num : egg_basic_fwd_rw_src

syntax "[" egg_side egg_subexpr_pos "]" : egg_tc_proj
syntax "[" (egg_side)? egg_subexpr_pos "]" : egg_tc_proj

syntax "→" : egg_tc_spec_dir
syntax "←" : egg_tc_spec_dir
syntax "<" egg_tc_spec_dir ">" : egg_tc_spec

syntax egg_basic_fwd_rw_src (egg_tc_proj)? (egg_tc_spec)? : egg_fwd_rw_src
syntax "⊢" egg_tc_proj : egg_fwd_rw_src
syntax "⊢" egg_tc_proj (egg_tc_spec)? : egg_fwd_rw_src
syntax "↣" num egg_tc_proj (egg_tc_spec)? : egg_fwd_rw_src
syntax "≡0" : egg_fwd_rw_src
syntax "≡→S" : egg_fwd_rw_src
syntax "≡S→" : egg_fwd_rw_src
Expand Down Expand Up @@ -123,18 +124,25 @@ private def parseBasicFwdRwSrc : (TSyntax `egg_basic_fwd_rw_src) → Source
| _ => unreachable!

private def parseFwdRwSrc : (TSyntax `egg_fwd_rw_src) → Source
| `(egg_fwd_rw_src|⊢[$side$pos]) => .tcProj .goal (parseSide side) (parseSubexprPos pos)
| `(egg_fwd_rw_src|≡0) => .natLit .zero
| `(egg_fwd_rw_src|≡→S) => .natLit .toSucc
| `(egg_fwd_rw_src|≡S→) => .natLit .ofSucc
| `(egg_fwd_rw_src|≡+) => .natLit .add
| `(egg_fwd_rw_src|≡-) => .natLit .sub
| `(egg_fwd_rw_src|≡*) => .natLit .mul
| `(egg_fwd_rw_src|≡^) => .natLit .pow
| `(egg_fwd_rw_src|≡/) => .natLit .div
| `(egg_fwd_rw_src|"≡%") => .natLit .mod
| `(egg_fwd_rw_src|≡η) => .eta
| `(egg_fwd_rw_src|≡β) => .beta
| `(egg_fwd_rw_src|≡0) => .natLit .zero
| `(egg_fwd_rw_src|≡→S) => .natLit .toSucc
| `(egg_fwd_rw_src|≡S→) => .natLit .ofSucc
| `(egg_fwd_rw_src|≡+) => .natLit .add
| `(egg_fwd_rw_src|≡-) => .natLit .sub
| `(egg_fwd_rw_src|≡*) => .natLit .mul
| `(egg_fwd_rw_src|≡^) => .natLit .pow
| `(egg_fwd_rw_src|≡/) => .natLit .div
| `(egg_fwd_rw_src|"≡%") => .natLit .mod
| `(egg_fwd_rw_src|≡η) => .eta
| `(egg_fwd_rw_src|≡β) => .beta
| `(egg_fwd_rw_src|⊢[$tcProjSide$tcProjPos]$[<$tcSpecDir?>]?) => Id.run do
let mut src := Source.tcProj .goal (parseSide tcProjSide) (parseSubexprPos tcProjPos)
if let some tcSpecDir := tcSpecDir? then src := .tcSpec src (parsTcSpecDir tcSpecDir)
return src
| `(egg_fwd_rw_src|↣$idx[$[$tcProjSide]?$tcProjPos]$[<$tcSpecDir?>]?) => Id.run do
let mut src := Source.tcProj (.guide idx.getNat) (tcProjSide.map parseSide) (parseSubexprPos tcProjPos)
if let some tcSpecDir := tcSpecDir? then src := .tcSpec src (parsTcSpecDir tcSpecDir)
return src
| `(egg_fwd_rw_src|$src:egg_basic_fwd_rw_src$[[$tcProjSide?$tcProjPos?]]?$[<$tcSpecDir?>]?) => Id.run do
let mut src := parseBasicFwdRwSrc src
if let some tcProjSide := tcProjSide? then
Expand Down
17 changes: 17 additions & 0 deletions Lean/Egg/Tactic/Guides.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import Egg.Core.Guides
import Lean
open Lean Elab Tactic

namespace Egg.Guides

declare_syntax_cat egg_guides
syntax " via " (term,*) : egg_guides

def parseGuides : TSyntax `egg_guides → TacticM Guides
| `(egg_guides|via $gs,*) => do
let mut guides : Guides := #[]
for g in gs.getElems, idx in [:gs.getElems.size] do
let guide ← Guide.from (← Tactic.elabTerm g none) (.guide idx)
guides := guides.push guide
return guides
| _ => unreachable!
15 changes: 9 additions & 6 deletions Lean/Egg/Tests/Groups.lean
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ instance [Group α] : OfNat α 0 where ofNat := zero

variable [Group G] {a b : G}

-- Note: Using `@add_assoc` etc, produces `.proj` expressions.
macro "group" : tactic => `(tactic|
egg [add_assoc, zero_add, add_zero, add_left_neg, add_right_neg]
open Egg.Guides Egg.Config.Modifier in
macro "group" mod:egg_cfg_mod base:(egg_base)? guides:(egg_guides)? : tactic => `(tactic|
egg $mod [add_assoc, zero_add, add_zero, add_left_neg, add_right_neg] $[$base]? $[$guides]?
)

theorem neg_add_cancel_left : -a + (a + b) = b := by group
Expand All @@ -38,6 +38,9 @@ theorem inv_inv : -(-a) = a := by
calc _ = -(-a) + (-a + a) := by group
_ = _ := by group

/-
group via -(-a) + (-a + a), -b + -a + (a + b) + -(a + b)
-/
theorem neg_add' : -(a + b) = -b + -a := by
group via -b + -a + (a + b) + -(a + b)

-- BUG: Try adding a hypothesis. This should cause the backend to crash.
theorem inv_inv' : -(-a) = a := by
group via -(-a) + (-a + a)
3 changes: 3 additions & 0 deletions Rust/src/basic.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::time::Duration;

use egg::*;
use crate::result::*;
use crate::analysis::*;
Expand Down Expand Up @@ -47,6 +49,7 @@ pub fn explain_congr(init: String, goal: String, rw_templates: Vec<RewriteTempla

let mut runner = Runner::default()
.with_egraph(egraph)
.with_time_limit(Duration::from_secs(10))
.with_hook(move |runner| {
if let Some(path) = &viz_path {
runner.egraph.dot().to_dot(format!("{}/{}.dot", path, runner.iterations.len())).unwrap();
Expand Down

0 comments on commit d83ed78

Please sign in to comment.