Skip to content

Commit

Permalink
fix tests by passing md to EGraphs.make
Browse files Browse the repository at this point in the history
  • Loading branch information
jumerckx committed Sep 17, 2024
1 parent d4c4c25 commit 6d008b9
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 7 deletions.
6 changes: 4 additions & 2 deletions docs/src/egraphs.md
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ The `EGraph{E,A}` type is parametrized by the expression type `E` and the

The following functions define an interface for analyses based on multiple dispatch:

* [make(g::EGraph{ExprType, AnalysisType}, n)](@ref) should take an e-node `n::VecExpr` and return a value from the analysis domain.
* [make(g::EGraph{ExprType, AnalysisType}, n, md)](@ref) should take an e-node `n::VecExpr`, and metadata `md` from an expression (possibly `noting`), and return a value from the analysis domain.
* [join(x::AnalysisType, y::AnalysisType)](@ref) should return the semilattice join of `x` and `y` in the analysis domain (e.g. *given two analyses value from ENodes in the same EClass, which one should I choose?* or *how should they be merged?*).`Base.isless` must be defined.
* [modify!(g::EGraph{ExprType, AnalysisType}, eclass::EClass{AnalysisType})](@ref) Can be optionally implemented. This can be used modify an EClass `egraph[eclass.id]` on-the-fly during an e-graph saturation iteration, given its analysis value, typically by adding an e-node.

Expand Down Expand Up @@ -325,7 +325,9 @@ From the definition of an e-node, we know that children of e-nodes are always ID
to e-classes in the `EGraph`.

```@example custom_analysis
function EGraphs.make(g::EGraph{ExpressionType,OddEvenAnalysis}, op, n::VecExpr) where {ExpressionType}
function EGraphs.make(g::EGraph{ExpressionType,OddEvenAnalysis}, op, n::VecExpr, md) where {ExpressionType}
# metadata `md` is not used in this instance.
v_isexpr(n) || return odd_even_base_case(op)
# The e-node is not a literal value,
# Let's consider only binary function call terms.
Expand Down
5 changes: 3 additions & 2 deletions src/EGraphs/egraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,9 @@ function process_unions!(g::EGraph{ExpressionType,AnalysisType})::Int where {Exp
eclass_id = find(g, eclass_id)
eclass_id_key = IdKey(eclass_id)
eclass = g.classes[eclass_id_key]
md = eclass.data

node_data = make(g, node)
node_data = make(g, node, md)
if !isnothing(eclass.data)
joined_data = join(eclass.data, node_data)

Expand Down Expand Up @@ -469,7 +470,7 @@ end
function check_analysis(g)
for (id, eclass) in g.classes
isnothing(eclass.data) && continue
pass = mapreduce(x -> make(g, x), (x, y) -> join(x, y), eclass)
pass = mapreduce(x -> make(g, x, x.data), (x, y) -> join(x, y), eclass)
@assert eclass.data == pass
end
true
Expand Down
2 changes: 1 addition & 1 deletion test/egraphs/analysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Base.:(*)(a::NumberFoldAnalysis, b::NumberFoldAnalysis) = NumberFoldAnalysis(a.n
Base.:(+)(a::NumberFoldAnalysis, b::NumberFoldAnalysis) = NumberFoldAnalysis(a.n + b.n)

# This should be auto-generated by a macro
function EGraphs.make(g::EGraph{ExpressionType,NumberFoldAnalysis}, n::VecExpr) where {ExpressionType}
function EGraphs.make(g::EGraph{ExpressionType,NumberFoldAnalysis}, n::VecExpr, md) where {ExpressionType}
h = get_constant(g, v_head(n))
v_isexpr(n) || return h isa Number ? NumberFoldAnalysis(h) : nothing
if v_iscall(n) && v_arity(n) == 2
Expand Down
2 changes: 1 addition & 1 deletion test/integration/cas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ end
# ex = rewrite(ex, canonical_t; clean=false)


function EGraphs.make(g::EGraph{Expr,Type}, n::VecExpr)
function EGraphs.make(g::EGraph{Expr,Type}, n::VecExpr, md)
h = get_constant(g, v_head(n))
v_isexpr(n) || return (h in (:im, im) ? Complex : typeof(h))
v_iscall(n) || return (Any)
Expand Down
2 changes: 1 addition & 1 deletion test/tutorials/lambda_theory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ const LambdaAnalysis = Set{Symbol}

getdata(eclass) = eclass.data

function EGraphs.make(g::EGraph{ExprType,LambdaAnalysis}, n::VecExpr) where {ExprType}
function EGraphs.make(g::EGraph{ExprType,LambdaAnalysis}, n::VecExpr, md) where {ExprType}
v_isexpr(n) || return LambdaAnalysis()
if v_iscall(n)
h = v_head(n)
Expand Down

0 comments on commit 6d008b9

Please sign in to comment.