diff --git a/docs/src/egraphs.md b/docs/src/egraphs.md index 9fe1d229..fa6a80a7 100644 --- a/docs/src/egraphs.md +++ b/docs/src/egraphs.md @@ -269,7 +269,7 @@ The `EGraph{E,A}` type is parametrized by the expression type `E` and the The following functions define an interface for analyses based on multiple dispatch: -* [make(g::EGraph{ExprType, AnalysisType}, n)](@ref) should take an e-node `n::VecExpr` and return a value from the analysis domain. +* [make(g::EGraph{ExprType, AnalysisType}, n, md)](@ref) should take an e-node `n::VecExpr`, and metadata `md` from an expression (possibly `noting`), and return a value from the analysis domain. * [join(x::AnalysisType, y::AnalysisType)](@ref) should return the semilattice join of `x` and `y` in the analysis domain (e.g. *given two analyses value from ENodes in the same EClass, which one should I choose?* or *how should they be merged?*).`Base.isless` must be defined. * [modify!(g::EGraph{ExprType, AnalysisType}, eclass::EClass{AnalysisType})](@ref) Can be optionally implemented. This can be used modify an EClass `egraph[eclass.id]` on-the-fly during an e-graph saturation iteration, given its analysis value, typically by adding an e-node. @@ -325,7 +325,9 @@ From the definition of an e-node, we know that children of e-nodes are always ID to e-classes in the `EGraph`. ```@example custom_analysis -function EGraphs.make(g::EGraph{ExpressionType,OddEvenAnalysis}, op, n::VecExpr) where {ExpressionType} +function EGraphs.make(g::EGraph{ExpressionType,OddEvenAnalysis}, op, n::VecExpr, md) where {ExpressionType} + # metadata `md` is not used in this instance. + v_isexpr(n) || return odd_even_base_case(op) # The e-node is not a literal value, # Let's consider only binary function call terms. diff --git a/src/EGraphs/egraph.jl b/src/EGraphs/egraph.jl index fe84b1ae..1a980f39 100644 --- a/src/EGraphs/egraph.jl +++ b/src/EGraphs/egraph.jl @@ -424,8 +424,9 @@ function process_unions!(g::EGraph{ExpressionType,AnalysisType})::Int where {Exp eclass_id = find(g, eclass_id) eclass_id_key = IdKey(eclass_id) eclass = g.classes[eclass_id_key] + md = eclass.data - node_data = make(g, node) + node_data = make(g, node, md) if !isnothing(eclass.data) joined_data = join(eclass.data, node_data) @@ -469,7 +470,7 @@ end function check_analysis(g) for (id, eclass) in g.classes isnothing(eclass.data) && continue - pass = mapreduce(x -> make(g, x), (x, y) -> join(x, y), eclass) + pass = mapreduce(x -> make(g, x, x.data), (x, y) -> join(x, y), eclass) @assert eclass.data == pass end true diff --git a/test/egraphs/analysis.jl b/test/egraphs/analysis.jl index 092031f6..1266e0b4 100644 --- a/test/egraphs/analysis.jl +++ b/test/egraphs/analysis.jl @@ -15,7 +15,7 @@ Base.:(*)(a::NumberFoldAnalysis, b::NumberFoldAnalysis) = NumberFoldAnalysis(a.n Base.:(+)(a::NumberFoldAnalysis, b::NumberFoldAnalysis) = NumberFoldAnalysis(a.n + b.n) # This should be auto-generated by a macro -function EGraphs.make(g::EGraph{ExpressionType,NumberFoldAnalysis}, n::VecExpr) where {ExpressionType} +function EGraphs.make(g::EGraph{ExpressionType,NumberFoldAnalysis}, n::VecExpr, md) where {ExpressionType} h = get_constant(g, v_head(n)) v_isexpr(n) || return h isa Number ? NumberFoldAnalysis(h) : nothing if v_iscall(n) && v_arity(n) == 2 diff --git a/test/integration/cas.jl b/test/integration/cas.jl index 08a34749..23151e83 100644 --- a/test/integration/cas.jl +++ b/test/integration/cas.jl @@ -215,7 +215,7 @@ end # ex = rewrite(ex, canonical_t; clean=false) -function EGraphs.make(g::EGraph{Expr,Type}, n::VecExpr) +function EGraphs.make(g::EGraph{Expr,Type}, n::VecExpr, md) h = get_constant(g, v_head(n)) v_isexpr(n) || return (h in (:im, im) ? Complex : typeof(h)) v_iscall(n) || return (Any) diff --git a/test/tutorials/lambda_theory.jl b/test/tutorials/lambda_theory.jl index 9f74a725..8df1eea8 100644 --- a/test/tutorials/lambda_theory.jl +++ b/test/tutorials/lambda_theory.jl @@ -104,7 +104,7 @@ const LambdaAnalysis = Set{Symbol} getdata(eclass) = eclass.data -function EGraphs.make(g::EGraph{ExprType,LambdaAnalysis}, n::VecExpr) where {ExprType} +function EGraphs.make(g::EGraph{ExprType,LambdaAnalysis}, n::VecExpr, md) where {ExprType} v_isexpr(n) || return LambdaAnalysis() if v_iscall(n) h = v_head(n)