|
3 | 3 | import pytest
|
4 | 4 | from hypothesis import given, settings
|
5 | 5 | from hypothesis import strategies as st
|
6 |
| -from sklearn.datasets import make_classification |
| 6 | +from sklearn.datasets import make_blobs, make_classification, make_regression |
7 | 7 |
|
8 | 8 | import lleaves
|
9 | 9 |
|
@@ -156,3 +156,60 @@ def test_multiclass_generated(tmpdir):
|
156 | 156 | lgbm.predict(X, n_jobs=2), llvm.predict(X, n_jobs=2), decimal=10
|
157 | 157 | )
|
158 | 158 | assert lgbm.num_model_per_iteration() == llvm.num_model_per_iteration()
|
| 159 | + |
| 160 | + |
def test_random_forest_classifier(tmpdir):
    # Two well-separated gaussian blobs -> an easy binary classification task.
    X, y = make_blobs(n_samples=100, centers=[[-4, -4], [4, 4]], random_state=42)

    # boosting_type="rf" makes LightGBM average tree outputs (random forest)
    # instead of summing boosted stages; rf mode requires bagging to be on.
    clf = lightgbm.LGBMClassifier(
        boosting_type="rf",
        n_estimators=7,
        bagging_freq=1,
        bagging_fraction=0.8,
    ).fit(X, y)
    model_file = str(tmpdir / "model.txt")
    clf.booster_.save_model(model_file)

    reference = lightgbm.Booster(model_file=model_file)
    compiled = lleaves.Model(model_file=model_file)
    compiled.compile()

    # Compiled predictions must agree with LightGBM's own to 10 decimals
    # over the full training set.
    np.testing.assert_almost_equal(
        reference.predict(X, n_jobs=2), compiled.predict(X, n_jobs=2), decimal=10
    )
    assert reference.num_model_per_iteration() == compiled.num_model_per_iteration()
| 186 | + |
@pytest.mark.parametrize("num_trees", [34, 35])
def test_random_forest_regressor(tmpdir, num_trees):
    # Synthetic noisy regression data; num_trees is parametrized (34 vs 35) —
    # presumably to hit both sides of some internal batching/block boundary
    # in the compiler (NOTE(review): confirm against lleaves internals).
    X, y = make_regression(n_samples=1000, n_features=5, noise=10.0)

    # Random-forest mode: subsample/colsample enable the bagging that
    # LightGBM requires for boosting_type="rf".
    regressor = lightgbm.LGBMRegressor(
        objective="regression",
        n_jobs=1,
        boosting_type="rf",
        subsample_freq=1,
        subsample=0.9,
        colsample_bytree=0.9,
        num_leaves=25,
        n_estimators=num_trees,
        min_child_samples=100,
        verbose=0,
    ).fit(X, y)

    model_file = str(tmpdir / "model.txt")
    regressor.booster_.save_model(model_file)

    reference = lightgbm.Booster(model_file=model_file)
    compiled = lleaves.Model(model_file=model_file)
    compiled.compile()

    # The compiled model must reproduce LightGBM's predictions to 10 decimals.
    np.testing.assert_almost_equal(
        reference.predict(X, n_jobs=2), compiled.predict(X, n_jobs=2), decimal=10
    )
0 commit comments