Skip to content

Commit 989c346

Browse files
authored
Merge pull request #21 from fuyw/add_model_name_in_compilation
Add ability to specify root function's name in compiled binary
2 parents cd1a144 + ecbd0ad commit 989c346

File tree

4 files changed

+31
-5
lines changed

4 files changed

+31
-5
lines changed

lleaves/compiler/codegen/codegen.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class LTree:
3939
class_id: int
4040

4141

42-
def gen_forest(forest, module, fblocksize):
42+
def gen_forest(forest, module, fblocksize, froot_func_name):
4343
"""
4444
Populate the passed IR module with code for the forest.
4545
@@ -81,7 +81,7 @@ def gen_forest(forest, module, fblocksize):
8181
root_func = ir.Function(
8282
module,
8383
ir.FunctionType(ir.VoidType(), (DOUBLE_PTR, DOUBLE_PTR, INT, INT)),
84-
name="forest_root",
84+
name=froot_func_name,
8585
)
8686

8787
def make_tree(tree):

lleaves/compiler/tree_compiler.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,18 @@
77
from lleaves.compiler.codegen import gen_forest
88

99

10-
def compile_to_module(file_path, fblocksize=34, finline=True, raw_score=False):
10+
def compile_to_module(
11+
file_path,
12+
fblocksize=34,
13+
finline=True,
14+
raw_score=False,
15+
froot_func_name="forest_root",
16+
):
1117
forest = parse_to_ast(file_path)
1218
forest.raw_score = raw_score
1319

1420
ir = llvmlite.ir.Module(name="forest")
15-
gen_forest(forest, ir, fblocksize)
21+
gen_forest(forest, ir, fblocksize, froot_func_name)
1622

1723
ir.triple = llvm.get_process_triple()
1824
module = llvm.parse_assembly(str(ir))

lleaves/lleaves.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def compile(
8585
fblocksize=34,
8686
fcodemodel="large",
8787
finline=True,
88+
froot_func_name="forest_root",
8889
):
8990
"""
9091
Generate the LLVM IR for this model and compile it to ASM.
@@ -107,6 +108,8 @@ def compile(
107108
very large forests.
108109
:param finline: Whether or not to inline functions. Setting this to False will reduce compilation time
109110
significantly but will slow down prediction.
111+
:param froot_func_name: Name of the entry point function in the compiled binary. This is the function to link against when
112+
writing a C function wrapper. Defaults to "forest_root".
110113
"""
111114
assert 0 < fblocksize
112115
assert fcodemodel in ("small", "large")
@@ -117,6 +120,7 @@ def compile(
117120
raw_score=raw_score,
118121
fblocksize=fblocksize,
119122
finline=finline,
123+
froot_func_name=froot_func_name,
120124
)
121125
else:
122126
# when loading binary from cache we use a dummy empty module
@@ -128,7 +132,7 @@ def compile(
128132
)
129133

130134
# Drops GIL during call, re-acquires it after
131-
addr = self._execution_engine.get_function_address("forest_root")
135+
addr = self._execution_engine.get_function_address(froot_func_name)
132136
self._c_entry_func = ENTRY_FUNC_TYPE(addr)
133137

134138
self.is_compiled = True

tests/test_compile_flags.py

+16
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,19 @@ def test_no_inline(NYC_data):
7474
llvm_model.predict(NYC_data[:1000], n_jobs=2),
7575
lgbm_model.predict(NYC_data[:1000], n_jobs=2),
7676
)
77+
78+
79+
def test_function_name():
80+
llvm_model = Model(model_file="tests/models/tiniest_single_tree/model.txt")
81+
lgbm_model = Booster(model_file="tests/models/tiniest_single_tree/model.txt")
82+
llvm_model.compile(froot_func_name="tiniest_single_tree_123132_XXX-")
83+
84+
data = [
85+
[1.0] * 3,
86+
[0.0] * 3,
87+
[-1.0] * 3,
88+
]
89+
np.testing.assert_almost_equal(
90+
llvm_model.predict(data, n_jobs=2),
91+
lgbm_model.predict(data, n_jobs=2),
92+
)

0 commit comments

Comments
 (0)