Skip to content

Commit ce507a7

Browse files
authored
Refactor generated function implementation to provide bindings/method (JuliaLang#57230)
This PR refactors the generated function implementation in multiple ways: 1. Rather than allocating a new LineNumber node to pass to the generator, we just pass the original method from which this LineNumberNode was constructed. This has been a bit of a longer-standing annoyance of mine, since the generator needs to know properties of the original method to properly interpret the return value from the generator, but this information was only available on the C side. 2. Move the handling of `Expr` returns fully into Julia. Right not things were a bit split with the julia code post-processing an `Expr` return, but then handing it back to C for lowering. By moving it fully into Julia, we can keep the C-side interface simpler by always getting a `CodeInfo`. With these refactorings done, amend the post-processing code to provide binding edges for `Expr` returns. Ordinarily, bindings in lowered code do not need edges, because we will scan the lowered code of the method to find them. However, generated functions are different, because we do not in general have the lowered code available. To still give them binding edges, we simply scan through the post-lowered code and all of the bindings we find into the edges array. I will note that both of these will require minor adjustments to `@generated` functions that use the CodeInfo interface (N.B.: this interface is not considered stable and we've broken it in almost every release so far). In particular, the following adjustments need to be made: 1. Adjusting the `source` argument to the new `Method` ABI 2. If necessary, adding any edges that correspond to GlobalRefs used - the code will treat the returned CodeInfo mostly opaquely and (unlike in the `Expr` case) will not automatically compute these edges.
1 parent 12698af commit ce507a7

10 files changed

+91
-83
lines changed

base/Base_compiler.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -234,13 +234,13 @@ include("abstractarray.jl")
234234
include("baseext.jl")
235235

236236
include("c.jl")
237-
include("ntuple.jl")
238237
include("abstractset.jl")
239238
include("bitarray.jl")
240239
include("bitset.jl")
241240
include("abstractdict.jl")
242241
include("iddict.jl")
243242
include("idset.jl")
243+
include("ntuple.jl")
244244
include("iterators.jl")
245245
using .Iterators: zip, enumerate, only
246246
using .Iterators: Flatten, Filter, product # for generators

base/boot.jl

-21
Original file line numberDiff line numberDiff line change
@@ -777,27 +777,6 @@ struct GeneratedFunctionStub
777777
spnames::SimpleVector
778778
end
779779

780-
# invoke and wrap the results of @generated expression
781-
function (g::GeneratedFunctionStub)(world::UInt, source::LineNumberNode, @nospecialize args...)
782-
# args is (spvals..., argtypes...)
783-
body = g.gen(args...)
784-
file = source.file
785-
file isa Symbol || (file = :none)
786-
lam = Expr(:lambda, Expr(:argnames, g.argnames...).args,
787-
Expr(:var"scope-block",
788-
Expr(:block,
789-
source,
790-
Expr(:meta, :push_loc, file, :var"@generated body"),
791-
Expr(:return, body),
792-
Expr(:meta, :pop_loc))))
793-
spnames = g.spnames
794-
if spnames === svec()
795-
return lam
796-
else
797-
return Expr(Symbol("with-static-parameters"), lam, spnames...)
798-
end
799-
end
800-
801780
# If the generator is a subtype of this trait, inference caches the generated unoptimized
802781
# code, sacrificing memory space to improve the performance of subsequent inferences.
803782
# This tradeoff is not appropriate in general cases (e.g., for `GeneratedFunctionStub`s

base/expr.jl

+43
Original file line numberDiff line numberDiff line change
@@ -1654,3 +1654,46 @@ end
16541654
function quoted(@nospecialize(x))
16551655
return is_self_quoting(x) ? x : QuoteNode(x)
16561656
end
1657+
1658+
# Implementation of generated functions
1659+
function generated_body_to_codeinfo(ex::Expr, defmod::Module, isva::Bool)
1660+
ci = ccall(:jl_expand, Any, (Any, Any), ex, defmod)
1661+
if !isa(ci, CodeInfo)
1662+
if isa(ci, Expr) && ci.head === :error
1663+
error("syntax: $(ci.args[1])")
1664+
end
1665+
error("The function body AST defined by this @generated function is not pure. This likely means it contains a closure, a comprehension or a generator.")
1666+
end
1667+
ci.isva = isva
1668+
code = ci.code
1669+
bindings = IdSet{Core.Binding}()
1670+
for i = 1:length(code)
1671+
stmt = code[i]
1672+
if isa(stmt, GlobalRef)
1673+
push!(bindings, convert(Core.Binding, stmt))
1674+
end
1675+
end
1676+
if !isempty(bindings)
1677+
ci.edges = Core.svec(bindings...)
1678+
end
1679+
return ci
1680+
end
1681+
1682+
# invoke and wrap the results of @generated expression
1683+
function (g::Core.GeneratedFunctionStub)(world::UInt, source::Method, @nospecialize args...)
1684+
# args is (spvals..., argtypes...)
1685+
body = g.gen(args...)
1686+
file = source.file
1687+
file isa Symbol || (file = :none)
1688+
lam = Expr(:lambda, Expr(:argnames, g.argnames...).args,
1689+
Expr(:var"scope-block",
1690+
Expr(:block,
1691+
LineNumberNode(Int(source.line), source.file),
1692+
Expr(:meta, :push_loc, file, :var"@generated body"),
1693+
Expr(:return, body),
1694+
Expr(:meta, :pop_loc))))
1695+
spnames = g.spnames
1696+
return generated_body_to_codeinfo(spnames === Core.svec() ? lam : Expr(Symbol("with-static-parameters"), lam, spnames...),
1697+
typename(typeof(g.gen)).module,
1698+
source.isva)
1699+
end

base/invalidation.jl

+12-10
Original file line numberDiff line numberDiff line change
@@ -93,26 +93,28 @@ function scan_edge_list(ci::Core.CodeInstance, binding::Core.Binding)
9393
end
9494

9595
function invalidate_method_for_globalref!(gr::GlobalRef, method::Method, invalidated_bpart::Core.BindingPartition, new_max_world::UInt)
96+
invalidate_all = false
97+
binding = convert(Core.Binding, gr)
9698
if isdefined(method, :source)
9799
src = _uncompressed_ir(method)
98-
binding = convert(Core.Binding, gr)
99100
old_stmts = src.code
100101
invalidate_all = should_invalidate_code_for_globalref(gr, src)
101-
for mi in specializations(method)
102-
isdefined(mi, :cache) || continue
103-
ci = mi.cache
104-
while true
105-
if ci.max_world > new_max_world && (invalidate_all || scan_edge_list(ci, binding))
106-
ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), ci, new_max_world)
107-
end
108-
isdefined(ci, :next) || break
109-
ci = ci.next
102+
end
103+
for mi in specializations(method)
104+
isdefined(mi, :cache) || continue
105+
ci = mi.cache
106+
while true
107+
if ci.max_world > new_max_world && (invalidate_all || scan_edge_list(ci, binding))
108+
ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), ci, new_max_world)
110109
end
110+
isdefined(ci, :next) || break
111+
ci = ci.next
111112
end
112113
end
113114
end
114115

115116
function invalidate_code_for_globalref!(gr::GlobalRef, invalidated_bpart::Core.BindingPartition, new_max_world::UInt)
117+
b = convert(Core.Binding, gr)
116118
try
117119
valid_in_valuepos = false
118120
foreach_module_mtable(gr.mod, new_max_world) do mt::Core.MethodTable

src/jl_exported_funcs.inc

-1
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@
127127
XX(jl_exit_on_sigint) \
128128
XX(jl_exit_threaded_region) \
129129
XX(jl_expand) \
130-
XX(jl_expand_and_resolve) \
131130
XX(jl_expand_stmt) \
132131
XX(jl_expand_stmt_with_loc) \
133132
XX(jl_expand_with_loc) \

src/julia_internal.h

+1
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void);
722722
JL_DLLEXPORT void jl_resolve_definition_effects_in_ir(jl_array_t *stmts, jl_module_t *m, jl_svec_t *sparam_vals, jl_value_t *binding_edge,
723723
int binding_effects);
724724
JL_DLLEXPORT void jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t *defining_module, jl_value_t *edge);
725+
JL_DLLEXPORT void jl_add_binding_backedge(jl_binding_t *b, jl_value_t *edge);
725726

726727
int get_next_edge(jl_array_t *list, int i, jl_value_t** invokesig, jl_code_instance_t **caller) JL_NOTSAFEPOINT;
727728
int set_next_edge(jl_array_t *list, int i, jl_value_t *invokesig, jl_code_instance_t *caller);

src/method.c

+9-37
Original file line numberDiff line numberDiff line change
@@ -604,8 +604,7 @@ static jl_value_t *jl_call_staged(jl_method_t *def, jl_value_t *generator,
604604
size_t totargs = 2 + n_sparams + def->nargs;
605605
JL_GC_PUSHARGS(gargs, totargs);
606606
gargs[0] = jl_box_ulong(world);
607-
gargs[1] = jl_box_long(def->line);
608-
gargs[1] = jl_new_struct(jl_linenumbernode_type, gargs[1], def->file);
607+
gargs[1] = (jl_value_t*)def;
609608
memcpy(&gargs[2], jl_svec_data(sparam_vals), n_sparams * sizeof(void*));
610609
memcpy(&gargs[2 + n_sparams], args, (def->nargs - def->isva) * sizeof(void*));
611610
if (def->isva)
@@ -615,23 +614,6 @@ static jl_value_t *jl_call_staged(jl_method_t *def, jl_value_t *generator,
615614
return code;
616615
}
617616

618-
// Lower `ex` into Julia IR, and (if it expands into a CodeInfo) resolve global-variable
619-
// references in light of the provided type parameters.
620-
// Like `jl_expand`, if there is an error expanding the provided expression, the return value
621-
// will be an error expression (an `Expr` with `error_sym` as its head), which should be eval'd
622-
// in the caller's context.
623-
JL_DLLEXPORT jl_code_info_t *jl_expand_and_resolve(jl_value_t *ex, jl_module_t *module,
624-
jl_svec_t *sparam_vals) {
625-
jl_code_info_t *func = (jl_code_info_t*)jl_expand((jl_value_t*)ex, module);
626-
JL_GC_PUSH1(&func);
627-
if (jl_is_code_info(func)) {
628-
jl_array_t *stmts = (jl_array_t*)func->code;
629-
jl_resolve_definition_effects_in_ir(stmts, module, sparam_vals, NULL, 1);
630-
}
631-
JL_GC_POP();
632-
return func;
633-
}
634-
635617
JL_DLLEXPORT jl_code_instance_t *jl_cached_uninferred(jl_code_instance_t *codeinst, size_t world)
636618
{
637619
for (; codeinst; codeinst = jl_atomic_load_relaxed(&codeinst->next)) {
@@ -703,25 +685,12 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *mi, size_t
703685
ex = jl_call_staged(def, generator, world, mi->sparam_vals, jl_svec_data(ttdt->parameters), jl_nparams(ttdt));
704686

705687
// do some post-processing
706-
if (jl_is_code_info(ex)) {
707-
func = (jl_code_info_t*)ex;
708-
jl_array_t *stmts = (jl_array_t*)func->code;
709-
jl_resolve_definition_effects_in_ir(stmts, def->module, mi->sparam_vals, NULL, 1);
710-
}
711-
else {
712-
// Lower the user's expression and resolve references to the type parameters
713-
func = jl_expand_and_resolve(ex, def->module, mi->sparam_vals);
714-
if (!jl_is_code_info(func)) {
715-
if (jl_is_expr(func) && ((jl_expr_t*)func)->head == jl_error_sym) {
716-
ct->ptls->in_pure_callback = 0;
717-
jl_toplevel_eval(def->module, (jl_value_t*)func);
718-
}
719-
jl_error("The function body AST defined by this @generated function is not pure. This likely means it contains a closure, a comprehension or a generator.");
720-
}
721-
// TODO: This should ideally be in the lambda expression,
722-
// but currently our isva determination is non-syntactic
723-
func->isva = def->isva;
688+
if (!jl_is_code_info(ex)) {
689+
jl_error("As of Julia 1.12, generated functions must return `CodeInfo`. See `Base.generated_body_to_codeinfo`.");
724690
}
691+
func = (jl_code_info_t*)ex;
692+
jl_array_t *stmts = (jl_array_t*)func->code;
693+
jl_resolve_definition_effects_in_ir(stmts, def->module, mi->sparam_vals, NULL, 1);
725694
ex = NULL;
726695

727696
// If this generated function has an opaque closure, cache it for
@@ -778,6 +747,9 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *mi, size_t
778747
if (jl_is_method_instance(kind)) {
779748
jl_method_instance_add_backedge((jl_method_instance_t*)kind, jl_nothing, ci);
780749
}
750+
else if (jl_is_binding(kind)) {
751+
jl_add_binding_backedge((jl_binding_t*)kind, (jl_value_t*)ci);
752+
}
781753
else if (jl_is_mtable(kind)) {
782754
assert(i < l);
783755
ex = data[i++];

src/module.c

+15-10
Original file line numberDiff line numberDiff line change
@@ -1099,6 +1099,20 @@ void jl_invalidate_binding_refs(jl_globalref_t *ref, jl_binding_partition_t *inv
10991099
JL_GC_POP();
11001100
}
11011101

1102+
JL_DLLEXPORT void jl_add_binding_backedge(jl_binding_t *b, jl_value_t *edge)
1103+
{
1104+
if (!b->backedges) {
1105+
b->backedges = jl_alloc_vec_any(0);
1106+
jl_gc_wb(b, b->backedges);
1107+
} else if (jl_array_len(b->backedges) > 0 &&
1108+
jl_array_ptr_ref(b->backedges, jl_array_len(b->backedges)-1) == edge) {
1109+
// Optimization: Deduplicate repeated insertion of the same edge (e.g. during
1110+
// definition of a method that contains many references to the same global)
1111+
return;
1112+
}
1113+
jl_array_ptr_1d_push(b->backedges, edge);
1114+
}
1115+
11021116
// Called for all GlobalRefs found in lowered code. Adds backedges for cross-module
11031117
// GlobalRefs.
11041118
JL_DLLEXPORT void jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t *defining_module, jl_value_t *edge)
@@ -1114,16 +1128,7 @@ JL_DLLEXPORT void jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t
11141128
jl_binding_t *b = gr->binding;
11151129
if (!b)
11161130
b = jl_get_module_binding(gr->mod, gr->name, 1);
1117-
if (!b->backedges) {
1118-
b->backedges = jl_alloc_vec_any(0);
1119-
jl_gc_wb(b, b->backedges);
1120-
} else if (jl_array_len(b->backedges) > 0 &&
1121-
jl_array_ptr_ref(b->backedges, jl_array_len(b->backedges)-1) == edge) {
1122-
// Optimization: Deduplicate repeated insertion of the same edge (e.g. during
1123-
// definition of a method that contains many references to the same global)
1124-
return;
1125-
}
1126-
jl_array_ptr_1d_push(b->backedges, edge);
1131+
jl_add_binding_backedge(b, edge);
11271132
}
11281133

11291134
JL_DLLEXPORT void jl_disable_binding(jl_globalref_t *gr)

test/rebinding.jl

+7
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,13 @@ module Rebinding
5555
@test f_return_delete_me_indirect() == 3
5656
Base.delete_binding(@__MODULE__, :delete_me)
5757
@test_throws UndefVarError f_return_delete_me_indirect()
58+
59+
# + via generated function
60+
const delete_me = 4
61+
@generated f_generated_return_delete_me() = return :(delete_me)
62+
@test f_generated_return_delete_me() == 4
63+
Base.delete_binding(@__MODULE__, :delete_me)
64+
@test_throws UndefVarError f_generated_return_delete_me()
5865
end
5966

6067
module RebindingPrecompile

test/staged.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -381,15 +381,15 @@ let
381381
@test length(ir.cfg.blocks) == 1
382382
end
383383

384-
function generate_lambda_ex(world::UInt, source::LineNumberNode,
384+
function generate_lambda_ex(world::UInt, source::Method,
385385
argnames, spnames, @nospecialize body)
386386
stub = Core.GeneratedFunctionStub(identity, Core.svec(argnames...), Core.svec(spnames...))
387387
return stub(world, source, body)
388388
end
389389

390390
# Test that `Core.CachedGenerator` works as expected
391391
struct Generator54916 <: Core.CachedGenerator end
392-
function (::Generator54916)(world::UInt, source::LineNumberNode, args...)
392+
function (::Generator54916)(world::UInt, source::Method, args...)
393393
return generate_lambda_ex(world, source,
394394
(:doit54916, :func, :arg), (), :(func(arg)))
395395
end
@@ -432,7 +432,7 @@ function overdubbee54341(a, b)
432432
a + b
433433
end
434434
const overdubee_codeinfo54341 = code_lowered(overdubbee54341, Tuple{Any, Any})[1]
435-
function overdub_generator54341(world::UInt, source::LineNumberNode, selftype, fargtypes)
435+
function overdub_generator54341(world::UInt, source::Method, selftype, fargtypes)
436436
if length(fargtypes) != 2
437437
return generate_lambda_ex(world, source,
438438
(:overdub54341, :args), (), :(error("Wrong number of arguments")))

0 commit comments

Comments
 (0)