Skip to content

Commit

Permalink
Add backend support for guides
Browse files Browse the repository at this point in the history
  • Loading branch information
marcusrossel committed Apr 12, 2024
1 parent e1916f5 commit f6cc768
Show file tree
Hide file tree
Showing 17 changed files with 94 additions and 29 deletions.
33 changes: 23 additions & 10 deletions C/ffi.c
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@ typedef struct rewrite {
/*
structure Rewrite.Encoded where
name : String
lhs : Expression
rhs : Expression
lhs : String
rhs : String
dirs : Directions
abbrev Expression := String
inductive Directions where
| none
| forward
Expand All @@ -50,6 +48,16 @@ rewrite* rewrites_from_lean_obj(lean_obj_arg rws) {
return rust_rws;
}

const char** guides_from_lean_obj(lean_obj_arg guides) {
lean_object** guides_c_ptr = lean_array_cptr(guides);
size_t guides_count = lean_array_size(guides);
const char** rust_guides = malloc(guides_count * sizeof(const char*));
for (int idx = 0; idx < guides_count; idx++) {
rust_guides[idx] = lean_string_cstr(guides_c_ptr[idx]);
}
return rust_guides;
}

typedef struct config {
_Bool optimize_expl;
_Bool gen_nat_lit_rws;
Expand Down Expand Up @@ -92,15 +100,18 @@ extern char* egg_explain_congr(
const char* goal,
rewrite* rws,
size_t rws_count,
const char** guides,
size_t guides_count,
config cfg,
const char* viz_path
);

/*
structure Egg.Request where
lhs : Expression
rhs : Expression
rws : Rewrites.Encoded
lhs : String
rhs : String
rws : Array Rewrite.Encoded
guides : Array String
vizPath : String
cfg : Request.Config
*/
Expand All @@ -109,10 +120,12 @@ lean_obj_res run_egg_request(lean_obj_arg req) {
const char* rhs = lean_string_cstr(lean_ctor_get(req, 1));
rewrite* rws = rewrites_from_lean_obj(lean_ctor_get(req, 2));
size_t rws_count = lean_array_size(lean_ctor_get(req, 2));
const char* viz_path = lean_string_cstr(lean_ctor_get(req, 3));
config cfg = config_from_lean_obj(lean_ctor_get(req, 4));
const char** guides = guides_from_lean_obj(lean_ctor_get(req, 3));
size_t guides_count = lean_array_size(lean_ctor_get(req, 3));
const char* viz_path = lean_string_cstr(lean_ctor_get(req, 4));
config cfg = config_from_lean_obj(lean_ctor_get(req, 5));

char* result = egg_explain_congr(lhs, rhs, rws, rws_count, cfg, viz_path);
char* result = egg_explain_congr(lhs, rhs, rws, rws_count, guides, guides_count, cfg, viz_path);
free(rws);

return lean_mk_string(result);
Expand Down
5 changes: 4 additions & 1 deletion Lean/Egg.lean
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import Egg.Core.Encode.Basic
import Egg.Core.Encode.EncodeM
import Egg.Core.Encode.Guides
import Egg.Core.Encode.Rewrites
import Egg.Core.Explanation.Basic
import Egg.Core.Explanation.Proof
import Egg.Core.Gen.TcProjs
import Egg.Core.Gen.TcSpecs
import Egg.Core.Rewrites.Basic
import Egg.Core.Config
import Egg.Core.Congr
import Egg.Core.Directions
import Egg.Core.Guides
import Egg.Core.MVars
import Egg.Core.Request
import Egg.Core.Rewrites
import Egg.Core.Source
import Egg.Tactic.Config.Modifier
import Egg.Tactic.Config.Option
Expand Down
13 changes: 13 additions & 0 deletions Lean/Egg/Core/Encode/Guides.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import Egg.Core.Guides
import Egg.Core.Encode.Basic
import Lean
open Lean

namespace Egg

abbrev Guide.Encoded := Expression

abbrev Guides.Encoded := Array Guide.Encoded

def Guides.encode (guides : Guides) (cfg : Config.Encoding) : MetaM Guides.Encoded :=
guides.mapM fun guide => Egg.encode guide.expr guide.src cfg
2 changes: 1 addition & 1 deletion Lean/Egg/Core/Encode/Rewrites.lean
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import Egg.Core.Encode.Basic
import Egg.Core.Rewrites.Basic
import Egg.Core.Rewrites
import Lean
open Lean

Expand Down
2 changes: 1 addition & 1 deletion Lean/Egg/Core/Explanation/Proof.lean
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import Egg.Core.Explanation.Basic
import Egg.Core.Explanation.Congr
import Egg.Core.Rewrites.Basic
import Egg.Core.Rewrites
open Lean Meta

-- TODO: Simplify tracing by adding `MessageData` instances for relevant types.
Expand Down
2 changes: 1 addition & 1 deletion Lean/Egg/Core/Gen/TcProjs.lean
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import Egg.Core.Rewrites.Basic
import Egg.Core.Rewrites
import Lean
open Lean Meta

Expand Down
2 changes: 1 addition & 1 deletion Lean/Egg/Core/Gen/TcSpecs.lean
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import Egg.Core.Rewrites.Basic
import Egg.Core.Rewrites
import Std.Tactic.Exact
import Lean
open Lean Meta
Expand Down
11 changes: 11 additions & 0 deletions Lean/Egg/Core/Guides.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import Egg.Core.Source
import Lean
open Lean

namespace Egg

structure Guide where
expr : Expr
src : Source

abbrev Guides := Array Guide
14 changes: 9 additions & 5 deletions Lean/Egg/Core/Request.lean
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import Egg.Core.Encode.Rewrites
import Egg.Core.Encode.Guides
import Egg.Core.Config
import Egg.Core.Explanation.Basic
import Egg.Core.Rewrites.Basic
import Egg.Core.Rewrites
open Lean

namespace Egg.Request
Expand Down Expand Up @@ -35,14 +36,17 @@ structure _root_.Egg.Request where
lhs : Expression
rhs : Expression
rws : Rewrites.Encoded
guides : Guides.Encoded
vizPath : String
cfg : Request.Config

def encoding (goal : Congr) (rws : Rewrites) (cfg : Egg.Config) : MetaM Request := do
def encoding (goal : Congr) (rws : Rewrites) (guides : Guides) (cfg : Egg.Config) :
MetaM Request := do
return {
lhs := ← encode goal.lhs .goal cfg.toEncoding
rhs := ← encode goal.rhs .goal cfg.toEncoding
rws := ← rws.encode cfg.toEncoding
lhs := ← encode goal.lhs .goal cfg.toEncoding
rhs := ← encode goal.rhs .goal cfg.toEncoding
rws := ← rws.encode cfg.toEncoding
guides := ← guides.encode cfg.toEncoding
vizPath := cfg.vizPath.getD ""
cfg
}
Expand Down
File renamed without changes.
2 changes: 2 additions & 0 deletions Lean/Egg/Core/Source.lean
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ inductive Source.NatLit where

inductive Source where
| goal
| guide (idx : Nat)
| explicit (idx : Nat) (eqn? : Option Nat)
| star (id : FVarId)
| tcProj (src : Source) (side : Side) (pos : SubExpr.Pos)
Expand All @@ -56,6 +57,7 @@ def NatLit.description : Source.NatLit → String

def description : Source → String
| goal => "⊢"
| guide idx => s!"↣{idx}"
| explicit idx none => s!"#{idx}"
| explicit idx (some eqn) => s!"#{idx}/{eqn}"
| star id => s!"*{id.uniqueIdx!}"
Expand Down
2 changes: 1 addition & 1 deletion Lean/Egg/Tactic/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ elab "egg " mod:egg_cfg_mod rws:egg_rws base:(egg_base)? : tactic => do
let amb ← getAmbientMVars
let goal ← parseGoal goal base
let rws ← genRewrites goal rws cfg
let req ← Request.encoding goal.type rws cfg
let req ← Request.encoding goal.type rws #[] cfg
req.trace
if cfg.exitPoint == .beforeEqSat then goal.id.admit; return
let rawExpl := req.run
Expand Down
2 changes: 1 addition & 1 deletion Lean/Egg/Tactic/Rewrites.lean
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import Egg.Core.Rewrites.Basic
import Egg.Core.Rewrites
import Lean

open Lean Meta Elab Tactic
Expand Down
12 changes: 8 additions & 4 deletions Lean/Egg/Tests/Groups.lean
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,13 @@ theorem add_neg_cancel_left : a + (-a + b) = b := by group
theorem neg_zero : -(0 : G) = 0 := by group

theorem neg_add : -(a + b) = -b + -a := by
calc -(a + b) = -b + -a + (a + b) + -(a + b) := by group
_ = -b + -a := by group
calc _ = -b + -a + (a + b) + -(a + b) := by group
_ = _ := by group

theorem inv_inv : -(-a) = a := by
calc -(-a) = -(-a) + (-a + a) := by group
_ = a := by group
calc _ = -(-a) + (-a + a) := by group
_ = _ := by group

/-
group via -(-a) + (-a + a), -b + -a + (a + b) + -(a + b)
-/
9 changes: 7 additions & 2 deletions Rust/src/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub struct Config {
trace_bvar_correction: bool,
}

pub fn explain_congr(init: String, goal: String, rw_templates: Vec<RewriteTemplate>, cfg: Config, viz_path: Option<String>) -> Res<String> {
pub fn explain_congr(init: String, goal: String, rw_templates: Vec<RewriteTemplate>, guides: Vec<String>, cfg: Config, viz_path: Option<String>) -> Res<String> {
init_enabled_trace_groups(cfg.trace_substitutions, cfg.trace_bvar_correction);

let mut egraph: LeanEGraph = Default::default();
Expand All @@ -30,7 +30,12 @@ pub fn explain_congr(init: String, goal: String, rw_templates: Vec<RewriteTempla
let goal_expr = goal.parse().map_err(|e : RecExprParseError<_>| Error::Goal(e.to_string()))?;
let init_id = egraph.add_expr(&init_expr);
let goal_id = egraph.add_expr(&goal_expr);


for guide in guides {
let expr = guide.parse().map_err(|e : RecExprParseError<_>| Error::Guide(e.to_string()))?;
egraph.add_expr(&expr);
}

let mut rws;
match templates_to_rewrites(rw_templates, cfg.block_invalid_matches, cfg.shift_captured_bvars) {
Ok(r) => rws = r,
Expand Down
10 changes: 9 additions & 1 deletion Rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ pub extern "C" fn egg_explain_congr(
goal_str_ptr: *const c_char,
rws_ptr: *const CRewrite,
rws_count: usize,
guides_ptr: *const *const c_char,
guides_count: usize,
cfg: Config,
viz_path_ptr: *const c_char
) -> *const c_char {
Expand All @@ -82,6 +84,7 @@ pub extern "C" fn egg_explain_congr(
let goal = String::from_utf8_lossy(goal_c_str.to_bytes()).to_string();
assert!(rws_ptr != null());
let c_rws = unsafe { std::slice::from_raw_parts(rws_ptr, rws_count) };
let c_guides = unsafe { std::slice::from_raw_parts(guides_ptr, guides_count) };

// Note: The `into_raw`s below are important, as otherwise Rust deallocates the string.
// TODO: I think this is a memory leak right now.
Expand All @@ -93,11 +96,16 @@ pub extern "C" fn egg_explain_congr(
}
let rw_templates = rw_templates.unwrap();

let guides = c_guides.iter().map(|&guide_c_str| {
let c_str = unsafe { CStr::from_ptr(guide_c_str) };
String::from_utf8_lossy(c_str.to_bytes()).to_string()
}).collect();

let viz_path_c_str = unsafe { CStr::from_ptr(viz_path_ptr) };
let raw_viz_path = String::from_utf8_lossy(viz_path_c_str.to_bytes()).to_string();
let viz_path = if raw_viz_path.is_empty() { None } else { Some(raw_viz_path) };

let expl = explain_congr(init, goal, rw_templates, cfg, viz_path);
let expl = explain_congr(init, goal, rw_templates, guides, cfg, viz_path);
if let Err(expl_err) = expl {
let rws_err_c_str = CString::new(expl_err.to_string()).expect("conversion of error message to C-string failed");
return rws_err_c_str.into_raw()
Expand Down
2 changes: 2 additions & 0 deletions Rust/src/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
pub enum Error {
Init(String),
Goal(String),
Guide(String),
Rewrite(String),
Failed,
}
Expand All @@ -12,6 +13,7 @@ impl ToString for Error {
match self {
Error::Init(s) => format!("⚡️ {s}"),
Error::Goal(s) => format!("⚡️ {s}"),
Error::Guide(s) => format!("⚡️ {s}"),
Error::Rewrite(s) => format!("⚡️ {s}"),
Error::Failed => "".to_string(),
}
Expand Down

0 comments on commit f6cc768

Please sign in to comment.