Skip to content

Commit

Permalink
Implement type class specialization
Browse files Browse the repository at this point in the history
  • Loading branch information
marcusrossel committed Apr 12, 2024
1 parent 13ffcb8 commit 39a1128
Show file tree
Hide file tree
Showing 23 changed files with 230 additions and 148 deletions.
4 changes: 2 additions & 2 deletions C/ffi.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ structure Rewrite.Encoded where
name : String
lhs : Expression
rhs : Expression
dirs : Rewrite.Directions
dirs : Directions
abbrev Expression := String
inductive Rewrite.Directions where
inductive Directions where
| none
| forward
| backward
Expand Down
3 changes: 2 additions & 1 deletion Lean/Egg.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ import Egg.Core.Encode.EncodeM
import Egg.Core.Explanation.Basic
import Egg.Core.Explanation.Proof
import Egg.Core.Gen.TcProjs
import Egg.Core.Gen.TcSpecs
import Egg.Core.Rewrites.Basic
import Egg.Core.Rewrites.Directions
import Egg.Core.Config
import Egg.Core.Congr
import Egg.Core.Directions
import Egg.Core.MVars
import Egg.Core.Request
import Egg.Core.Source
Expand Down
2 changes: 1 addition & 1 deletion Lean/Egg/Core/Config.lean
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ structure Encoding where

structure Gen where
genTcProjRws := true
genTcSpecRws := true
genNatLitRws := true
genEtaRw := true
genBetaRw := true
explode := true
deriving BEq

structure Backend where
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import Lean
open Lean

namespace Egg.Rewrite
namespace Egg

inductive Direction where
| forward
| backward
deriving Inhabited
deriving Inhabited, BEq, Hashable

def Direction.description : Direction → String
| .forward => "→"
| .backward => "←"

def Direction.merge : Direction → Direction → Direction
| .forward, .forward | .backward, .backward => .forward
Expand All @@ -28,6 +32,10 @@ instance : ToString Directions where
| .backward => "backward"
| .both => "both"

def contains : Directions → Direction → Bool
| .both, _ | .forward, .forward | .backward, .backward => true
| _, _ => false

-- The directions for which a given set is a superset of the other.
def satisfyingSuperset (lhs rhs : RBTree α cmp) : Directions :=
match rhs.subset lhs, lhs.subset rhs with
Expand Down
1 change: 0 additions & 1 deletion Lean/Egg/Core/Encode/EncodeM.lean
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import Egg.Core.Config
import Egg.Core.Source
import Egg.Core.Gen.Explosion
import Std.Data.List.Basic

open Lean
Expand Down
2 changes: 1 addition & 1 deletion Lean/Egg/Core/Encode/Rewrites.lean
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ structure Rewrite.Encoded where
name : String
lhs : Expression
rhs : Expression
dirs : Rewrite.Directions
dirs : Directions

def Rewrite.encode (cfg : Config.Encoding) (rw : Rewrite) : MetaM Encoded :=
return {
Expand Down
3 changes: 1 addition & 2 deletions Lean/Egg/Core/Explanation/Basic.lean
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import Egg.Core.Source
import Egg.Core.Rewrites.Directions
import Egg.Core.Directions

open Lean
open Egg.Rewrite (Direction)

namespace Egg.Explanation

Expand Down
40 changes: 0 additions & 40 deletions Lean/Egg/Core/Gen/Explosion.lean

This file was deleted.

36 changes: 36 additions & 0 deletions Lean/Egg/Core/Gen/TcSpecs.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import Egg.Core.Rewrites.Basic
import Std.Tactic.Exact
import Lean
open Lean Meta

namespace Egg

private partial def genSpecialization (rw : Rewrite) (dir : Direction) (missing : MVarIdSet) :
MetaM (Option Rewrite) := do
let (rw, subst) ← rw.freshWithSubst (src := .tcSpec rw.src dir)
let mut missing := missing.map subst.expr.fwd.find!
let mut changed := true
while changed do
changed := false
for var in missing do
if let some inst ← findLocalDeclWithType? (← var.getType) then
var.assign (.fvar inst)
missing := missing.erase var
changed := true
let rw ← rw.instantiateMVars
return if rw.validDirs.contains dir then rw else none

private def genTcSpecializationsForRw (rw : Rewrite) : MetaM Rewrites := do
let missingOnLhs := rw.rhsMVars.tc.subtract rw.lhsMVars.tc
let missingOnRhs := rw.lhsMVars.tc.subtract rw.rhsMVars.tc
let mut specs : Rewrites := #[]
if !missingOnLhs.isEmpty then
if let some spec ← genSpecialization rw .forward missingOnLhs then
specs := specs.push spec
if !missingOnRhs.isEmpty then
if let some spec ← genSpecialization rw .backward missingOnRhs then
specs := specs.push spec
return specs

def genTcSpecializations (targets : Rewrites) : MetaM Rewrites :=
targets.foldlM (init := #[]) fun acc rw => return acc ++ (← genTcSpecializationsForRw rw)
113 changes: 98 additions & 15 deletions Lean/Egg/Core/MVars.lean
Original file line number Diff line number Diff line change
@@ -1,36 +1,44 @@
import Egg.Lean
import Lean
open Lean
open Lean Meta

namespace Egg

structure MVars where
expr : MVarIdSet := ∅
lvl : LMVarIdSet := ∅
-- A subset of `expr` which tracks the mvars whose type is a type class.
tc : MVarIdSet := ∅

private def MVars.insertExpr (mvars : MVars) (id : MVarId) : MetaM MVars := do
let isClass := (← isClass? (← id.getType)).isSome
return { mvars with
expr := mvars.expr.insert id
tc := if isClass then mvars.tc.insert id else mvars.tc
}

private structure MVarCollectionState where
visitedExprs : ExprSet := {}
visitedLvls : LevelSet := {}
mvars : MVars := {}

private partial def collectMVars : Expr → MVarCollectionState → MVarCollectionState
private partial def collectMVars : Expr → MVarCollectionState → MetaM MVarCollectionState
| .mvar id => visitMVar id
| .const _ lvls => visitConst lvls
| .sort lvl => visitSort lvl
| .const _ lvls => (return visitConst lvls ·)
| .sort lvl => (return visitSort lvl ·)
| .proj _ _ e | .mdata _ e => visit e
| .forallE _ e₁ e₂ _ | .lam _ e₁ e₂ _ | .app e₁ e₂ => visit e₁ visit e₂
| .letE _ e₁ e₂ e₃ _ => visit e₁ visit e₂ visit e₃
| _ => id
| .forallE _ e₁ e₂ _ | .lam _ e₁ e₂ _ | .app e₁ e₂ => visit e₁ >=> visit e₂
| .letE _ e₁ e₂ e₃ _ => visit e₁ >=> visit e₂ >=> visit e₃
| _ => pure
where
visit (e : Expr) (s : MVarCollectionState) : MVarCollectionState :=
visit (e : Expr) (s : MVarCollectionState) : MetaM MVarCollectionState :=
if !e.hasMVar || s.visitedExprs.contains e then
s
return s
else
collectMVars e { s with visitedExprs := s.visitedExprs.insert e }

visitMVar (id : MVarId) (s : MVarCollectionState) : MVarCollectionState := { s with
mvars.expr := s.mvars.expr.insert id
}
visitMVar (id : MVarId) (s : MVarCollectionState) : MetaM MVarCollectionState :=
return { s with mvars := ← s.mvars.insertExpr id }

visitConst (lvls : List Level) (s : MVarCollectionState) : MVarCollectionState := Id.run do
let mut s := s
Expand All @@ -53,9 +61,84 @@ where
visitedLvls := s.visitedLvls.insert lvl
}

def MVars.collect (e : Expr) : MVars :=
collectMVars e {} |>.mvars
namespace MVars

def collect (e : Expr) : MetaM MVars :=
MVarCollectionState.mvars <$> collectMVars e {}

def MVars.merge (vars₁ vars₂ : MVars) : MVars where
def merge (vars₁ vars₂ : MVars) : MVars where
expr := vars₁.expr.merge vars₂.expr
lvl := vars₁.lvl.merge vars₂.lvl

protected structure Subst.Expr where
fwd : HashMap MVarId MVarId := ∅
bwd : HashMap MVarId MVarId := ∅

protected abbrev Subst.Lvl := HashMap LMVarId LMVarId

structure Subst where
expr : Subst.Expr := {}
lvl : Subst.Lvl := ∅

def Subst.apply (subst : Subst) (e : Expr) : Expr :=
e.replace replaceExpr
where
replaceExpr : Expr → Option Expr
| .mvar id => subst.expr.fwd.find? id >>= (Expr.mvar ·)
| .sort lvl => Expr.sort <| lvl.replace replaceLvl
| .const name lvls => Expr.const name <| lvls.map (·.replace replaceLvl)
| _ => none

replaceLvl : Level → Option Level
| .mvar id => subst.lvl.find? id >>= (Level.mvar ·)
| _ => none

def fresh (mvars : MVars) (init : Subst := {}) : MetaM (MVars × Subst) := do
let (exprVars, exprSubst) ← freshExprs mvars.expr init.expr
let (lvlVars, lvlSubst) ← freshLvls mvars.lvl init.lvl
let subst := { expr := exprSubst, lvl := lvlSubst }
assignFreshExprMVarTypes exprVars subst
return ({ expr := exprVars, lvl := lvlVars }, subst)
where
freshExprs (src : MVarIdSet) (subst : Subst.Expr) : MetaM (MVarIdSet × Subst.Expr) := do
let mut vars : MVarIdSet := {}
let mut subst := subst
for var in src do
if let some f := subst.fwd.find? var then
vars := vars.insert f
else
-- Note: As the type of an mvar may also contain mvars, we also have to replace mvars with
-- their fresh counterpart *in the type*. We can only do this once we know the fresh
-- counterpart for each mvar, so we postpone the type assignment.
let f ← mkFreshExprMVar none
subst := {
fwd := subst.fwd.insert var f.mvarId!
bwd := subst.bwd.insert f.mvarId! var
}
vars := vars.insert f.mvarId!
return (vars, subst)

freshLvls (src : LMVarIdSet) (subst : Subst.Lvl) : MetaM (LMVarIdSet × Subst.Lvl) := do
let mut vars : LMVarIdSet := {}
let mut subst := subst
for var in src do
if let some f := subst.find? var then
vars := vars.insert f
else
let f ← mkFreshLevelMVar
subst := subst.insert var f.mvarId!
vars := vars.insert f.mvarId!
return (vars, subst)

assignFreshExprMVarTypes (vars : MVarIdSet) (subst : Subst) : MetaM Unit := do
for var in vars do
let srcType ← (subst.expr.bwd.find! var).getType
let freshType := subst.apply srcType
var.setType freshType

def removeAssigned (mvars : MVars) : MetaM MVars := do
return {
expr := ← mvars.expr.filterM fun var => return !(← var.isAssigned)
lvl := ← mvars.lvl.filterM fun var => return !(← isLevelMVarAssigned var)
tc := ← mvars.tc.filterM fun var => return !(← var.isAssigned)
}
1 change: 0 additions & 1 deletion Lean/Egg/Core/Request.lean
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import Egg.Core.Encode.Rewrites
import Egg.Core.Config
import Egg.Core.Gen.Explosion
import Egg.Core.Explanation.Basic
import Egg.Core.Rewrites.Basic
open Lean
Expand Down
Loading

0 comments on commit 39a1128

Please sign in to comment.