Skip to content

Commit fa14013

Browse files
committed
Merge branch 'fix_overflow_largepredict'
2 parents 3b965a7 + 9784625 commit fa14013

File tree

10 files changed

+91
-34
lines changed

10 files changed

+91
-34
lines changed

.github/ci.sh

+1-1
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@ python -m pip install --no-use-pep517 --no-deps --disable-pip-version-check -e .
88
pytest -v tests
99

1010
# Check documentation build only in one job, also do releases
11-
if [ "${PYTHON_VERSION}" = "3.6" ]; then
11+
if [ "${PYTHON_VERSION}" = "3.7" ]; then
1212
pushd docs
1313
make html
1414
popd

.github/workflows/ci.yml

+1-4
Original file line number | Diff line number | Diff line change
@@ -34,14 +34,11 @@ jobs:
3434
with:
3535
path: ./.hypothesis
3636
key: hypothesisDB ${{ matrix.PYTHON_VERSION }}
37-
- if: matrix.PYTHON_VERSION == '3.6'
38-
shell: bash -x -l {0}
39-
run: pip install dataclasses
4037
- name: Run the unittests
4138
shell: bash -x -l {0}
4239
run: ./.github/ci.sh ${{ matrix.PYTHON_VERSION }}
4340
- name: Publish a Python distribution to PyPI
44-
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') && matrix.PYTHON_VERSION == '3.6'
41+
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') && matrix.PYTHON_VERSION == '3.7'
4542
uses: pypa/gh-action-pypi-publish@v1.4.2
4643
with:
4744
user: __token__

environment.yml

+1-1
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@ channels:
44
- nodefaults
55
dependencies:
66
# runtime deps
7-
- python>=3.6
7+
- python>=3.7
88
- llvmlite>=0.36
99
- numpy
1010
# testing

lleaves/compiler/codegen/codegen.py

+13-7
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
FLOAT = ir.FloatType()
1010
INT_CAT = ir.IntType(bits=32)
1111
INT = ir.IntType(bits=32)
12+
LONG = ir.IntType(bits=64)
1213
ZERO_V = ir.Constant(BOOL, 0)
1314
FLOAT_POINTER = ir.PointerType(FLOAT)
1415
DOUBLE_PTR = ir.PointerType(DOUBLE)
@@ -18,6 +19,10 @@ def iconst(value):
1819
return ir.Constant(INT, value)
1920

2021

22+
def lconst(value):
23+
return ir.Constant(LONG, value)
24+
25+
2126
def fconst(value):
2227
return ir.Constant(FLOAT, value)
2328

@@ -168,7 +173,9 @@ def _populate_instruction_block(
168173

169174
# -- SETUP BLOCK
170175
builder = ir.IRBuilder(setup_block)
171-
loop_iter = builder.alloca(INT, 1, "loop-idx")
176+
start_index = builder.zext(start_index, LONG)
177+
end_index = builder.zext(end_index, LONG)
178+
loop_iter = builder.alloca(LONG, 1, "loop-idx")
172179
builder.store(start_index, loop_iter)
173180
condition_block = root_func.append_basic_block("loop-condition")
174181
builder.branch(condition_block)
@@ -187,9 +194,9 @@ def _populate_instruction_block(
187194
args = []
188195
loop_iter_reg = builder.load(loop_iter)
189196

190-
n_args = ir.Constant(INT, forest.n_args)
197+
n_args = ir.Constant(LONG, forest.n_args)
191198
iter_mul_nargs = builder.mul(loop_iter_reg, n_args)
192-
idx = (builder.add(iter_mul_nargs, iconst(i)) for i in range(forest.n_args))
199+
idx = (builder.add(iter_mul_nargs, lconst(i)) for i in range(forest.n_args))
193200
raw_ptrs = [builder.gep(root_func.args[0], (c,)) for c in idx]
194201
# cast the categorical inputs to integer
195202
for feature, ptr in zip(forest.features, raw_ptrs):
@@ -203,9 +210,9 @@ def _populate_instruction_block(
203210
for func in tree_funcs:
204211
tree_res = builder.call(func.llvm_function, args)
205212
results[func.class_id] = builder.fadd(tree_res, results[func.class_id])
206-
res_idx = builder.mul(iconst(forest.n_classes), loop_iter_reg)
213+
res_idx = builder.mul(lconst(forest.n_classes), loop_iter_reg)
207214
results_ptr = [
208-
builder.gep(out_arr, (builder.add(res_idx, iconst(class_idx)),))
215+
builder.gep(out_arr, (builder.add(res_idx, lconst(class_idx)),))
209216
for class_idx in range(forest.n_classes)
210217
]
211218

@@ -224,8 +231,7 @@ def _populate_instruction_block(
224231
for result, result_ptr in zip(results, results_ptr):
225232
builder.store(result, result_ptr)
226233

227-
tmpp1 = builder.add(loop_iter_reg, iconst(1))
228-
builder.store(tmpp1, loop_iter)
234+
builder.store(builder.add(loop_iter_reg, lconst(1)), loop_iter)
229235
builder.branch(condition_block)
230236
# -- END CORE LOOP BLOCK
231237

lleaves/data_processing.py

+3-2
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
from typing import List, Optional
55

66
import numpy as np
7+
import pandas as pd
78

89
try:
910
from pandas import DataFrame as pd_DataFrame
@@ -15,7 +16,7 @@ class pd_DataFrame:
1516
pass
1617

1718

18-
def _dataframe_to_ndarray(data, pd_traintime_categories: List[List]):
19+
def _dataframe_to_ndarray(data: pd.DataFrame, pd_traintime_categories: List[List]):
1920
"""
2021
Converts the given dataframe into a 2D numpy array and converts categorical columns to float.
2122
@@ -94,7 +95,7 @@ def data_to_ndarray(data, pd_traintime_categories: Optional[List[List]] = None):
9495
return data
9596

9697

97-
def ndarray_to_ptr(data):
98+
def ndarray_to_ptr(data: np.ndarray):
9899
"""
99100
Takes a 2D numpy array, converts to float64 if necessary and returns a pointer
100101

lleaves/lleaves.py

+11-7
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import concurrent.futures
22
import math
33
import os
4-
from ctypes import CFUNCTYPE, POINTER, c_double, c_int
4+
from ctypes import CFUNCTYPE, POINTER, c_double, c_int32
55
from pathlib import Path
66

77
import llvmlite.binding
@@ -20,8 +20,8 @@
2020
None, # return void
2121
POINTER(c_double), # pointer to data array
2222
POINTER(c_double), # pointer to results array
23-
c_int, # start index
24-
c_int, # end index
23+
c_int32, # start index
24+
c_int32, # end index
2525
)
2626

2727

@@ -89,12 +89,10 @@ def compile(
8989
"""
9090
Generate the LLVM IR for this model and compile it to ASM.
9191
92-
For most users tweaking the compilation flags (fcodemodel, fblocksize) will be unnecessary as the default
93-
configuration is already very fast.
92+
For most users tweaking the compilation flags (fcodemodel, fblocksize, finline) will be unnecessary
93+
as the default configuration is already very fast.
9494
Modifying the flags is useful only if you're trying to squeeze out the last few percent of performance.
9595
96-
The compile() method is generally not thread-safe.
97-
9896
:param cache: Path to a cache file. If this path doesn't exist, binary will be dumped at path after compilation.
9997
If path exists, binary will be loaded and compilation skipped.
10098
No effort is made to check staleness / consistency.
@@ -160,6 +158,12 @@ def predict(self, data, n_jobs=os.cpu_count()):
160158
raise ValueError(
161159
f"Data must be of dimension (N, {self.num_feature()}), is {data.shape}."
162160
)
161+
# protect against `ctypes.c_int32` silently overflowing and causing SIGSEGV
162+
if n_predictions >= 2 ** 31 - 1:
163+
raise ValueError(
164+
"Prediction is not supported for datasets with >=2^31-1 rows. "
165+
"Split the dataset into smaller chunks first."
166+
)
163167

164168
# setup input data and predictions array
165169
ptr_data = ndarray_to_ptr(data)

setup.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,6 @@
2424
description="LLVM-based compiler for LightGBM models",
2525
long_description=long_description,
2626
long_description_content_type="text/markdown",
27-
python_requires=">=3.6",
28-
install_requires=["llvmlite>=0.36", "numpy", "dataclasses; python_version < '3.7'"],
27+
python_requires=">=3.7",
28+
install_requires=["llvmlite>=0.36", "numpy"],
2929
)

tests/conftest.py

+23
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,23 @@
1+
import pytest
2+
from lightgbm import Booster
3+
4+
from lleaves import Model
5+
6+
7+
@pytest.fixture(scope="session")
8+
def NYC_llvm():
9+
llvm_model = Model(model_file="tests/models/NYC_taxi/model.txt")
10+
llvm_model.compile()
11+
return llvm_model
12+
13+
14+
@pytest.fixture(scope="session")
15+
def NYC_lgbm():
16+
return Booster(model_file="tests/models/NYC_taxi/model.txt")
17+
18+
19+
@pytest.fixture(scope="session")
20+
def mtpl2_llvm():
21+
llvm_model = Model(model_file="tests/models/mtpl2/model.txt")
22+
llvm_model.compile()
23+
return llvm_model

tests/test_dataprocessing.py

+19
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,9 @@
33
import numpy as np
44
import pandas as pd
55
import pytest
6+
from lightgbm import Booster
67

8+
from lleaves import Model
79
from lleaves.data_processing import (
810
data_to_ndarray,
911
extract_model_global_features,
@@ -87,3 +89,20 @@ def test_no_data_modification():
8789
pred = pd.DataFrame(data).astype("category")
8890
ndarray_to_ptr(data_to_ndarray(pred, data))
8991
pd.testing.assert_frame_equal(pred, orig)
92+
93+
94+
def test_sliced_arrays():
95+
# predictions should be correct when passed a sliced array
96+
llvm_model = Model(model_file="tests/models/single_tree/model.txt")
97+
llvm_model.compile()
98+
lgbm_model = Booster(model_file="tests/models/single_tree/model.txt")
99+
100+
n_feature = lgbm_model.num_feature()
101+
data = np.array(list(range(-5 * n_feature, 5 * n_feature)), dtype=np.float64)
102+
data = data.reshape((5, 2 * n_feature))
103+
sliced = data[:, ::2]
104+
assert not sliced.flags.c_contiguous
105+
np.testing.assert_almost_equal(
106+
llvm_model.predict(sliced, n_jobs=4), lgbm_model.predict(sliced), decimal=13
107+
)
108+
return

tests/test_parallel.py

+17-10
Original file line number | Diff line number | Diff line change
@@ -1,30 +1,37 @@
11
from ctypes import POINTER, c_double
22

33
import numpy as np
4-
from lightgbm import Booster
54

6-
from lleaves import Model
75

6+
def test_parallel_edgecases(NYC_llvm, NYC_lgbm):
7+
# single row, multiple threads
8+
data = np.array(1 * [NYC_lgbm.num_feature() * [1.0]], dtype=np.float64)
9+
np.testing.assert_almost_equal(
10+
NYC_llvm.predict(data, n_jobs=4), NYC_lgbm.predict(data), decimal=14
11+
)
12+
13+
# last thread has only one prediction (batchsize is ceil(19/7)=3)
14+
data = np.array(19 * [NYC_lgbm.num_feature() * [1.0]], dtype=np.float64)
15+
np.testing.assert_almost_equal(
16+
NYC_llvm.predict(data, n_jobs=7), NYC_lgbm.predict(data), decimal=14
17+
)
818

9-
def test_parallel_iteration():
10-
llvm_model = Model(model_file="tests/models/NYC_taxi/model.txt")
11-
lgbm_model = Booster(model_file="tests/models/NYC_taxi/model.txt")
12-
llvm_model.compile()
1319

14-
data = np.array(4 * [5 * [1.0]], dtype=np.float64)
20+
def test_parallel_iteration(NYC_llvm, NYC_lgbm):
21+
data = np.array(4 * [NYC_lgbm.num_feature() * [1.0]], dtype=np.float64)
1522
data_flat = np.array(data.reshape(data.size), dtype=np.float64)
1623
np.testing.assert_almost_equal(
17-
llvm_model.predict(data, n_jobs=4), lgbm_model.predict(data), decimal=14
24+
NYC_llvm.predict(data, n_jobs=4), NYC_lgbm.predict(data), decimal=14
1825
)
1926

2027
ptr_data = data_flat.ctypes.data_as(POINTER(c_double))
2128
preds = np.zeros(4, dtype=np.float64)
2229
ptr_preds = preds.ctypes.data_as(POINTER(c_double))
2330

24-
llvm_model._c_entry_func(ptr_data, ptr_preds, 2, 4)
31+
NYC_llvm._c_entry_func(ptr_data, ptr_preds, 2, 4)
2532
preds_l = list(preds)
2633
assert preds_l[0] == 0.0 and preds_l[1] == 0.0
2734
assert preds_l[2] != 0.0 and preds_l[3] != 0.0
28-
llvm_model._c_entry_func(ptr_data, ptr_preds, 0, 2)
35+
NYC_llvm._c_entry_func(ptr_data, ptr_preds, 0, 2)
2936
preds_l = list(preds)
3037
assert preds_l[0] != 0.0 and preds_l[1] != 0.0

0 commit comments

Comments
 (0)