Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT]: Expression proof #226

Open
wants to merge 3 commits into
base: ale/3.0-proof
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "3.0.0"
[deps]
AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
Expand All @@ -19,6 +20,7 @@ Plotting = ["GraphViz"]
[compat]
AutoHashEquals = "2.1.0"
DocStringExtensions = "0.8, 0.9"
JSON = "0.21.4"
Reexport = "0.2, 1"
TermInterface = "2.0"
TimerOutputs = "0.5"
Expand Down
3 changes: 3 additions & 0 deletions src/EGraphs/EGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,7 @@ using .Schedulers
include("saturation.jl")
export SaturationParams, saturate!

include("exprproof.jl")
export PositionedProof, find_node_proof, detailed_dict

end
7 changes: 7 additions & 0 deletions src/EGraphs/egraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,13 @@ function Base.show(io::IO, g::EGraph)
end


function print_nodes(g::EGraph)
for (i, node) in enumerate(g.nodes)
println("[$i] => ", to_expr(g, node))
end
end
export print_nodes

function print_proof(g::EGraph)
# Print memo
println("explain_find:")
Expand Down
105 changes: 105 additions & 0 deletions src/EGraphs/exprproof.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
export PositionedProof, find_node_proof, detailed_dict


mutable struct PositionedProof
"""
Positioned proof is a structure that keeps track of where we apply proofs to in larger expressions.
"""
proof::Vector{ProofNode}
children::Vector{PositionedProof}
# TODO: Track what is matched
end

function detailed_dict(pc::ProofConnection, g::EGraph)
return Dict(
"next" => to_expr(g, g.nodes[pc.next]), # TODO: node should be unfolded, i.e., subexpressions should be exprs not node ids
"current" => to_expr(g, g.nodes[pc.current]), # TODO: node should be unfolded, i.e., subexpressions should be exprs not node ids
"justification" => pc.justification # TODO: Change to the rules name + params
)
end

function detailed_dict(pn::ProofNode, g::EGraph)
return Dict(
"existence_node" => to_expr(g, g.nodes[pn.existence_node]), # TODO: node should be unfolded, i.e., subexpressions should be exprs not node ids
"parent_connection" => detailed_dict(pn.parent_connection,g),
"neighbours" => map(x -> detailed_dict(x, g), pn.neighbours)
)
end

function detailed_dict(pp::PositionedProof, g::EGraph)
return Dict(
"proof" => map(x -> detailed_dict(x, g), pp.proof),
"children" => map(x -> detailed_dict(x, g), pp.children)
)
end

Base.show(io::IO, pp::PositionedProof) = begin
println(io, "PositionedProof(")
println(io, " proof = ", pp.proof)
println(io, " children = [")
for child in pp.children
show(io, child)
end
println(io, " ]")
println(io, ")")
end


function find_node_proof(g::EGraph, node1::Id, node2::Id)::Union{Tuple{PositionedProof, PositionedProof}, Nothing}
# Proof search that can deal with expressions, too.

# Idea:

# Walk expr trees

# For each node:
# If has flat proof, proof to leader
# Else, recursively unfold

# If no proof found for subexpr, return nothing

# Issues: how to relate expressions?
# Especially if different Size
# e.g. a*(b+c) = ab+bc (which is different size AST)
# bigger problem comes when a=z then z*(b+c) = ab+bc

# So I guess the way we should go about it is go to base terms, rewrite to leader



# Idea: rewrite both sides to "normal forms" and concat
# TODO: This is definetely suboptimal and should be optimized
# LCA?
leader1, nfproof1 = rewrite_to_normal_form(g, node1)
leader2, nfproof2 = rewrite_to_normal_form(g, node2)
println("==========")
println(leader1)
println(leader2)
println(g)
println(nfproof1)
println(nfproof2)
if leader1 != leader2
return nothing
end
return (nfproof1, nfproof2)


end

#
function rewrite_to_normal_form(g::EGraph, node::Id)::Tuple{Id,PositionedProof}
# Start off by rewriting node to leader
lp = rewrite_to_leader(g.proof, node)
leader = lp.leader
leader_proof = lp.proof

expr = g.nodes[leader]
proof = PositionedProof(leader_proof, [])
sizehint!(proof.children, v_arity(expr))
# Do we want to do this before or after tthe leader proof?
for child in v_children(expr)
_, child_proof = rewrite_to_normal_form(g, child)
push!(proof.children, child_proof)
end
return (leader, proof)
end
37 changes: 33 additions & 4 deletions src/EGraphs/proof.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
export ProofConnection, ProofNode, EGraphProof, find_flat_proof

export ProofConnection, ProofNode, EGraphProof, find_flat_proof, rewrite_to_leader
mutable struct ProofConnection
"""
Justification can be
Expand All @@ -24,6 +23,7 @@ end


mutable struct ProofNode
# TODO: Explain
existence_node::Id
# TODO is this the parent in the unionfind?
parent_connection::ProofConnection
Expand Down Expand Up @@ -79,6 +79,7 @@ function make_leader(proof::EGraphProof, node::Id)::Bool
true
end


function Base.union!(proof::EGraphProof, node1::Id, node2::Id, rule_idx::Int)
# TODO maybe should have extra argument called `rhs_new` in egg that is true when called from
# application of rules where the instantiation of the rhs creates new e-classes
Expand Down Expand Up @@ -106,7 +107,11 @@ end
@inline isroot(pn::ProofNode) = isroot(pn.parent_connection)
@inline isroot(pc::ProofConnection) = pc.current === pc.next

function find_flat_proof(proof::EGraphProof, node1::Id, node2::Id)




function find_flat_proof(proof::EGraphProof, node1::Id, node2::Id)::Vector{ProofNode}
# We're doing a lowest common ancestor search.
# We cache the IDs we have seen
seen_set = Set{Id}()
Expand All @@ -117,6 +122,9 @@ function find_flat_proof(proof::EGraphProof, node1::Id, node2::Id)
# No existence_node would ever have id 0
lca = UInt(0)
curr = proof.explain_find[node1]
if (node1 == node2)
return [curr]
end

# Walk up to the root
while true
Expand Down Expand Up @@ -155,4 +163,25 @@ function find_flat_proof(proof::EGraphProof, node1::Id, node2::Id)
# TODO maybe reverse
append!(ret, walk_from2)
ret
end
end

struct LeaderProof
leader::Id
proof::Vector{ProofNode}
end

function rewrite_to_leader(proof::EGraphProof, node::Id)::LeaderProof
# Returns the leader of e-class and a proof to transform node into said leader
curr_proof = proof.explain_find[node]
proofs = []
final_id = node
if curr_proof.parent_connection.current == curr_proof.parent_connection.next
return LeaderProof(node, [curr_proof]) # Special case to report congruence
end
while curr_proof.parent_connection.current != curr_proof.parent_connection.next
push!(proofs, curr_proof)
final_id = curr_proof.parent_connection.next
curr_proof = proof.explain_find[curr_proof.parent_connection.next]
end
return LeaderProof(final_id, proofs)
end
6 changes: 3 additions & 3 deletions src/Rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ Base.:(==)(a::RewriteRule, b::RewriteRule) = a.op == b.op && a.left == b.left &&

function Base.show(io::IO, r::RewriteRule)
if r.op == (|>) # Is dynamic rule, replace with =>
print(io, :($(r.left) => $(r.rhs_original)))
print(io, :($(r.left) => $(r.rhs_original) : $(r.name)))
else
print(io, :($(nameof(r.op))($(r.left), $(r.right))))
print(io, :($(nameof(r.op))($(r.left), $(r.right)) : $(r.name)))
end

if !isempty(r.name)
Expand Down Expand Up @@ -153,7 +153,7 @@ instantiate(_, pat::Union{PatVar,PatSegment}, bindings) = bindings[pat.idx]
"Inverts the direction of a rewrite rule, swapping the LHS and the RHS"
function Base.inv(r::RewriteRule)
RewriteRule(
name = r.name,
name = r.name + "-inv",
op = r.op,
left = r.right,
right = r.left,
Expand Down
41 changes: 38 additions & 3 deletions test/egraphs/proof.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
using Metatheory, Test
using Metatheory.Library
import JSON

g = EGraph(; proof = true)

id_a = addexpr!(g, :a)
println(find_flat_proof(g.proof, id_a, id_a))
@test length(find_flat_proof(g.proof, id_a, id_a)) == 1

# print_proof(g)

Expand Down Expand Up @@ -31,10 +35,41 @@ id_d = addexpr!(g, :d)

union!(g, id_a, id_d, 3)
print_proof(g)

println(find_flat_proof(g.proof, id_c, id_d))
# Takes 4 steps
@test length(find_flat_proof(g.proof, id_a, id_d)) == 4
@test length(find_flat_proof(g.proof, id_c, id_d)) == 3

# TODO: Why doesn't d have a its leader
for id in [id_a, id_b, id_c, id_d]
leader = rewrite_to_leader(g.proof, id)
@test leader.leader == id_d
@test length(leader.proof) == length(find_flat_proof(g.proof, id, id_a))
end



id_e = addexpr!(g, :e)
@test isempty(find_flat_proof(g.proof, id_a, id_e))
@test isempty(find_flat_proof(g.proof, id_a, id_e))

id_z = addexpr!(g, :z)

comm_monoid = @commutative_monoid (*) 1

fold_mul = @theory begin
~a::Number * ~b::Number => ~a * ~b
end

ex = :(a * b * 4 * z)
id_ex = addexpr!(g, ex)
ex_to = :(d * c * 4 * z)
id_ex_to = addexpr!(g, ex_to)
print_nodes(g)

println(pretty_dict(g))
prf = find_node_proof(g, id_ex, id_ex_to)
if prf === nothing
println("No proof")
else
println(JSON.json(detailed_dict(prf[1], g)))
println(JSON.json(detailed_dict(prf[2], g)))
end