Skip to content

Commit b5adfae

Browse files
authored
Merge pull request #59 from siboehm/siboehm/rf
Implement random forest
2 parents 47159bc + 371e407 commit b5adfae

File tree

6 files changed

+90
-4
lines changed

6 files changed

+90
-4
lines changed

.github/workflows/ci.yml

+1
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@ on:
33
push:
44
branches: [ master ]
55
pull_request:
6+
branches: ['*']
67
workflow_dispatch:
78

89
jobs:

lleaves/compiler/ast/nodes.py

+2
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,8 @@ class Forest:
2929
objective_func: str
3030
objective_func_config: str
3131
raw_score: bool = False
32+
# average output over trees instead of just accumulating
33+
average_output: bool = False
3234

3335
@property
3436
def n_args(self):

lleaves/compiler/ast/parser.py

+9-1
Original file line number | Diff line number | Diff line change
@@ -101,6 +101,7 @@ def parse_to_ast(model_path):
101101
objective = scanned_model["general_info"]["objective"]
102102
objective_func = objective[0]
103103
objective_func_config = objective[1] if len(objective) > 1 else None
104+
average_output = "average_output" in scanned_model["general_info"]
104105
features = [
105106
Feature(is_categorical_feature(x))
106107
for x in scanned_model["general_info"]["feature_infos"]
@@ -114,7 +115,14 @@ def parse_to_ast(model_path):
114115
)
115116
]
116117
assert len(trees) % n_classes == 0, "Ill formed model file"
117-
return Forest(trees, features, n_classes, objective_func, objective_func_config)
118+
return Forest(
119+
trees,
120+
features,
121+
n_classes,
122+
objective_func,
123+
objective_func_config,
124+
average_output=average_output,
125+
)
118126

119127

120128
def is_categorical_feature(feature_info: str):

lleaves/compiler/ast/scanner.py

+8-1
Original file line number | Diff line number | Diff line change
@@ -76,6 +76,7 @@ def __init__(self, type: type, is_list=False, null_ok=False):
7676
"version": ScannedValue(str),
7777
"feature_infos": ScannedValue(str, True),
7878
"objective": ScannedValue(str, True),
79+
"average_output": ScannedValue(bool, null_ok=True),
7980
}
8081
TREE_SCAN_KEYS = {
8182
"Tree": ScannedValue(int),
@@ -106,7 +107,13 @@ def _scan_block(lines: list, items_to_scan: dict):
106107
if line == "tree":
107108
continue
108109

109-
scanned_key, scanned_value = line.split("=")
110+
line_split = line.split("=")
111+
if len(line_split) == 2:
112+
scanned_key, scanned_value = line.split("=")
113+
else:
114+
assert len(line_split) == 1, f"Unexpected line {line}"
115+
scanned_key, scanned_value = line_split[0], True
116+
110117
target_type = items_to_scan.get(scanned_key)
111118
if target_type is None:
112119
continue

lleaves/compiler/codegen/codegen.py

+12-1
Original file line number | Diff line number | Diff line change
@@ -241,6 +241,8 @@ def _populate_instruction_block(
241241
forest.objective_func,
242242
forest.objective_func_config,
243243
forest.raw_score,
244+
forest.average_output,
245+
len(forest.trees),
244246
)
245247
for result, result_ptr in zip(results, results_ptr):
246248
builder.store(result, result_ptr)
@@ -279,7 +281,13 @@ def _populate_forest_func(forest, root_func, tree_funcs, fblocksize):
279281

280282

281283
def _populate_objective_func_block(
282-
builder, args, objective: str, objective_config: str, raw_score: bool
284+
builder,
285+
args,
286+
objective: str,
287+
objective_config: str,
288+
raw_score: bool,
289+
average_output: bool,
290+
num_trees: int,
283291
):
284292
"""
285293
Takes the objective function specification and generates the code for it into the builder
@@ -290,6 +298,9 @@ def _populate_objective_func_block(
290298
"llvm.copysign", (DOUBLE, DOUBLE), ir.FunctionType(DOUBLE, (DOUBLE, DOUBLE))
291299
)
292300

301+
if average_output:
302+
args[0] = builder.fdiv(args[0], dconst(num_trees))
303+
293304
def _populate_sigmoid(alpha):
294305
if alpha <= 0:
295306
raise ValueError(f"Sigmoid parameter needs to be >0, is {alpha}")

tests/test_tree_output.py

+58-1
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
import pytest
44
from hypothesis import given, settings
55
from hypothesis import strategies as st
6-
from sklearn.datasets import make_classification
6+
from sklearn.datasets import make_blobs, make_classification, make_regression
77

88
import lleaves
99

@@ -156,3 +156,60 @@ def test_multiclass_generated(tmpdir):
156156
lgbm.predict(X, n_jobs=2), llvm.predict(X, n_jobs=2), decimal=10
157157
)
158158
assert lgbm.num_model_per_iteration() == llvm.num_model_per_iteration()
159+
160+
161+
def test_random_forest_classifier(tmpdir):
    """Check that a compiled random-forest classifier predicts like LightGBM.

    A LightGBM classifier is trained in ``rf`` mode on two blobs, saved to
    a model file, and that file is loaded both by LightGBM and by lleaves.
    Both must produce the same predictions on the training data and report
    the same number of models per iteration.
    """
    X, y = make_blobs(
        n_samples=100,
        centers=[[-4, -4], [4, 4]],
        random_state=42,
    )

    # rf = random forest (outputs are averaged over all trees)
    classifier = lightgbm.LGBMClassifier(
        boosting_type="rf",
        n_estimators=7,
        bagging_freq=1,
        bagging_fraction=0.8,
    )
    classifier.fit(X, y)

    model_file = str(tmpdir / "model.txt")
    classifier.booster_.save_model(model_file)

    # load the very same model file through both implementations
    lgbm = lightgbm.Booster(model_file=model_file)
    llvm = lleaves.Model(model_file=model_file)
    llvm.compile()

    # check predictions equal on the whole dataset
    expected = lgbm.predict(X, n_jobs=2)
    actual = llvm.predict(X, n_jobs=2)
    np.testing.assert_almost_equal(expected, actual, decimal=10)

    assert lgbm.num_model_per_iteration() == llvm.num_model_per_iteration()
185+
186+
187+
@pytest.mark.parametrize("num_trees", [34, 35])
def test_random_forest_regressor(tmpdir, num_trees):
    """Check that a compiled random-forest regressor predicts like LightGBM.

    A LightGBM regressor is trained in ``rf`` mode with ``num_trees``
    estimators, saved to a model file, and that file is loaded both by
    LightGBM and by lleaves. Both must produce the same predictions on the
    training data.
    """
    # Fixed seed: keeps the generated dataset identical between runs so a
    # failure can be reproduced (the classifier test seeds its data as well).
    X, y = make_regression(
        n_samples=1000, n_features=5, noise=10.0, random_state=42
    )

    params = {
        "objective": "regression",
        "n_jobs": 1,
        "boosting_type": "rf",
        "subsample_freq": 1,
        "subsample": 0.9,
        "colsample_bytree": 0.9,
        "num_leaves": 25,
        "n_estimators": num_trees,
        "min_child_samples": 100,
        "verbose": 0,
    }

    model = lightgbm.LGBMRegressor(**params).fit(X, y)
    model_file = str(tmpdir / "model.txt")
    model.booster_.save_model(model_file)

    # load the very same model file through both implementations
    lgbm = lightgbm.Booster(model_file=model_file)
    llvm = lleaves.Model(model_file=model_file)
    llvm.compile()

    # check predictions equal on the whole dataset
    np.testing.assert_almost_equal(
        lgbm.predict(X, n_jobs=2), llvm.predict(X, n_jobs=2), decimal=10
    )

0 commit comments

Comments
 (0)