Skip to content

Commit 45d004b

Browse files
authored
Merge pull request #70 from zjzjwang/single_precision
add single precision (float32) mode
2 parents 4d01c6e + 1102962 commit 45d004b

File tree

6 files changed: +168 −62 lines changed

lleaves/compiler/codegen/codegen.py

+63-35
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
INT = ir.IntType(bits=32)
1212
LONG = ir.IntType(bits=64)
1313
ZERO_V = ir.Constant(BOOL, 0)
14-
FLOAT_POINTER = ir.PointerType(FLOAT)
14+
FLOAT_PTR = ir.PointerType(FLOAT)
1515
DOUBLE_PTR = ir.PointerType(DOUBLE)
1616

1717

@@ -33,6 +33,14 @@ def dconst(value):
3333
return ir.Constant(DOUBLE, value)
3434

3535

36+
def get_fdtype_const(value, use_fp64):
    """Wrap `value` as a floating-point IR constant in the selected precision.

    Delegates to `dconst` when `use_fp64` is truthy and to `fconst` otherwise.
    """
    if use_fp64:
        return dconst(value)
    return fconst(value)
38+
39+
40+
def get_fdtype(use_fp64):
    """Select the module-level IR float type for the requested precision.

    Returns `DOUBLE` when `use_fp64` is truthy, otherwise `FLOAT`.
    """
    if use_fp64:
        return DOUBLE
    return FLOAT
42+
43+
3644
@dataclass
3745
class LTree:
3846
"""Class for the LLVM function of a tree paired with relevant non-LLVM context"""
@@ -41,7 +49,7 @@ class LTree:
4149
class_id: int
4250

4351

44-
def gen_forest(forest, module, fblocksize, froot_func_name):
52+
def gen_forest(forest, module, fblocksize, froot_func_name, use_fp64):
4553
"""
4654
Populate the passed IR module with code for the forest.
4755
@@ -80,20 +88,23 @@ def gen_forest(forest, module, fblocksize, froot_func_name):
8088
"""
8189

8290
# entry function called from Python
91+
DTYPE_PTR = DOUBLE_PTR if use_fp64 else FLOAT_PTR
8392
root_func = ir.Function(
8493
module,
85-
ir.FunctionType(ir.VoidType(), (DOUBLE_PTR, DOUBLE_PTR, INT, INT)),
94+
ir.FunctionType(ir.VoidType(), (DTYPE_PTR, DTYPE_PTR, INT, INT)),
8695
name=froot_func_name,
8796
)
8897

8998
def make_tree(tree):
9099
# declare the function for this tree
91-
func_dtypes = (INT_CAT if f.is_categorical else DOUBLE for f in tree.features)
92-
scalar_func_t = ir.FunctionType(DOUBLE, func_dtypes)
100+
func_dtypes = (
101+
INT_CAT if f.is_categorical else get_fdtype(use_fp64) for f in tree.features
102+
)
103+
scalar_func_t = ir.FunctionType(get_fdtype(use_fp64), func_dtypes)
93104
tree_func = ir.Function(module, scalar_func_t, name=str(tree))
94105
tree_func.linkage = "private"
95106
# populate function with IR
96-
gen_tree(tree, tree_func)
107+
gen_tree(tree, tree_func, use_fp64)
97108
return LTree(llvm_function=tree_func, class_id=tree.class_id)
98109

99110
tree_funcs = [make_tree(tree) for tree in forest.trees]
@@ -102,30 +113,30 @@ def make_tree(tree):
102113
# better locality by running trees for each class together
103114
tree_funcs.sort(key=lambda t: t.class_id)
104115

105-
_populate_forest_func(forest, root_func, tree_funcs, fblocksize)
116+
_populate_forest_func(forest, root_func, tree_funcs, fblocksize, use_fp64)
106117

107118

108-
def gen_tree(tree, tree_func):
119+
def gen_tree(tree, tree_func, use_fp64):
109120
"""generate code for tree given the function, recursing into nodes"""
110121
node_block = tree_func.append_basic_block(name=str(tree.root_node))
111-
gen_node(tree_func, node_block, tree.root_node)
122+
gen_node(tree_func, node_block, tree.root_node, use_fp64)
112123

113124

114-
def gen_node(func, node_block, node):
125+
def gen_node(func, node_block, node, use_fp64):
115126
"""generate code for node, recursing into children"""
116127
if node.is_leaf:
117-
_gen_leaf_node(node_block, node)
128+
_gen_leaf_node(node_block, node, use_fp64)
118129
else:
119-
_gen_decision_node(func, node_block, node)
130+
_gen_decision_node(func, node_block, node, use_fp64)
120131

121132

122-
def _gen_leaf_node(node_block, leaf):
133+
def _gen_leaf_node(node_block, leaf, use_fp64):
123134
"""populate block with leaf's return value"""
124135
builder = ir.IRBuilder(node_block)
125-
builder.ret(dconst(leaf.value))
136+
builder.ret(get_fdtype_const(leaf.value, use_fp64))
126137

127138

128-
def _gen_decision_node(func, node_block, node):
139+
def _gen_decision_node(func, node_block, node, use_fp64):
129140
"""generate code for decision node, recursing into children"""
130141
builder = ir.IRBuilder(node_block)
131142

@@ -151,20 +162,24 @@ def _gen_decision_node(func, node_block, node):
151162
)
152163
builder = bitset_builder
153164
else:
154-
comp = _populate_numerical_node_block(func, builder, node)
165+
comp = _populate_numerical_node_block(func, builder, node, use_fp64)
155166

156167
# finalize this node's block with a terminal statement
157168
if is_fused_double_leaf_node:
158-
ret = builder.select(comp, dconst(node.left.value), dconst(node.right.value))
169+
ret = builder.select(
170+
comp,
171+
get_fdtype_const(node.left.value, use_fp64),
172+
get_fdtype_const(node.right.value, use_fp64),
173+
)
159174
builder.ret(ret)
160175
else:
161176
builder.cbranch(comp, left_block, right_block)
162177

163178
# populate generated child blocks
164179
if left_block:
165-
gen_node(func, left_block, node.left)
180+
gen_node(func, left_block, node.left, use_fp64)
166181
if right_block:
167-
gen_node(func, right_block, node.right)
182+
gen_node(func, right_block, node.right, use_fp64)
168183

169184

170185
def _populate_instruction_block(
@@ -175,6 +190,7 @@ def _populate_instruction_block(
175190
setup_block,
176191
next_block,
177192
eval_obj_func,
193+
use_fp64,
178194
):
179195
"""Generates an instruction_block: loops over all input data and evaluates its chunk of tree_funcs."""
180196
data_arr, out_arr, start_index, end_index = root_func.args
@@ -211,14 +227,14 @@ def _populate_instruction_block(
211227
el = builder.load(ptr)
212228
if feature.is_categorical:
213229
# first, check if the value is NaN
214-
is_nan = builder.fcmp_ordered("uno", el, dconst(0.0))
230+
is_nan = builder.fcmp_ordered("uno", el, get_fdtype_const(0.0, use_fp64))
215231
# if it is, return smallest possible int (will always go right), else cast to int
216232
el = builder.select(is_nan, iconst(-(2**31)), builder.fptosi(el, INT_CAT))
217233
args.append(el)
218234
else:
219235
args.append(el)
220236
# iterate over each tree, sum up results
221-
results = [dconst(0.0) for _ in range(forest.n_classes)]
237+
results = [get_fdtype_const(0.0, use_fp64) for _ in range(forest.n_classes)]
222238
for func in tree_funcs:
223239
tree_res = builder.call(func.llvm_function, args)
224240
results[func.class_id] = builder.fadd(tree_res, results[func.class_id])
@@ -243,6 +259,7 @@ def _populate_instruction_block(
243259
forest.raw_score,
244260
forest.average_output,
245261
len(forest.trees),
262+
use_fp64,
246263
)
247264
for result, result_ptr in zip(results, results_ptr):
248265
builder.store(result, result_ptr)
@@ -252,7 +269,7 @@ def _populate_instruction_block(
252269
# -- END CORE LOOP BLOCK
253270

254271

255-
def _populate_forest_func(forest, root_func, tree_funcs, fblocksize):
272+
def _populate_forest_func(forest, root_func, tree_funcs, fblocksize, use_fp64):
256273
"""Populate root function IR for forest"""
257274

258275
assert fblocksize > 0
@@ -277,6 +294,7 @@ def _populate_forest_func(forest, root_func, tree_funcs, fblocksize):
277294
setup_block,
278295
next_block,
279296
eval_objective_func,
297+
use_fp64,
280298
)
281299

282300

@@ -288,28 +306,30 @@ def _populate_objective_func_block(
288306
raw_score: bool,
289307
average_output: bool,
290308
num_trees: int,
309+
use_fp64: bool,
291310
):
292311
"""
293312
Takes the objective function specification and generates the code for it into the builder
294313
"""
295-
llvm_exp = builder.module.declare_intrinsic("llvm.exp", (DOUBLE,))
296-
llvm_log = builder.module.declare_intrinsic("llvm.log", (DOUBLE,))
314+
DTYPE = get_fdtype(use_fp64)
315+
llvm_exp = builder.module.declare_intrinsic("llvm.exp", (DTYPE,))
316+
llvm_log = builder.module.declare_intrinsic("llvm.log", (DTYPE,))
297317
llvm_copysign = builder.module.declare_intrinsic(
298-
"llvm.copysign", (DOUBLE, DOUBLE), ir.FunctionType(DOUBLE, (DOUBLE, DOUBLE))
318+
"llvm.copysign", (DTYPE, DTYPE), ir.FunctionType(DTYPE, (DTYPE, DTYPE))
299319
)
300320

301321
if average_output:
302-
args[0] = builder.fdiv(args[0], dconst(num_trees))
322+
args[0] = builder.fdiv(args[0], get_fdtype_const(num_trees, use_fp64))
303323

304324
def _populate_sigmoid(alpha):
305325
if alpha <= 0:
306326
raise ValueError(f"Sigmoid parameter needs to be >0, is {alpha}")
307327

308328
# 1 / (1 + exp(- alpha * x))
309-
inner = builder.fmul(dconst(-alpha), args[0])
329+
inner = builder.fmul(get_fdtype_const(-alpha, use_fp64), args[0])
310330
exp = builder.call(llvm_exp, [inner])
311-
denom = builder.fadd(dconst(1.0), exp)
312-
return builder.fdiv(dconst(1.0), denom)
331+
denom = builder.fadd(get_fdtype_const(1.0, use_fp64), exp)
332+
return builder.fdiv(get_fdtype_const(1.0, use_fp64), denom)
313333

314334
# raw score means we don't need to add the objective function
315335
if raw_score:
@@ -324,7 +344,10 @@ def _populate_sigmoid(alpha):
324344
# naive implementation which will be numerically unstable for small x.
325345
# should be changed to log1p
326346
exp = builder.call(llvm_exp, [args[0]])
327-
result = builder.call(llvm_log, [builder.fadd(dconst(1.0), exp)])
347+
result = builder.call(
348+
llvm_log, [builder.fadd(get_fdtype_const(1.0, use_fp64), exp)]
349+
)
350+
328351
elif objective in ("poisson", "gamma", "tweedie"):
329352
result = builder.call(llvm_exp, [args[0]])
330353
elif objective in (
@@ -347,7 +370,7 @@ def _populate_sigmoid(alpha):
347370
# TODO Might profit from vectorization, needs testing
348371
result = [builder.call(llvm_exp, [arg]) for arg in args]
349372

350-
denominator = dconst(0.0)
373+
denominator = get_fdtype_const(0.0, use_fp64)
351374
for r in result:
352375
denominator = builder.fadd(r, denominator)
353376

@@ -391,11 +414,12 @@ def _populate_categorical_node_block(
391414
return comp
392415

393416

394-
def _populate_numerical_node_block(func, builder, node):
417+
def _populate_numerical_node_block(func, builder, node, use_fp64):
395418
"""populate block with IR for numerical node"""
396419
val = func.args[node.split_feature]
397420

398-
thresh = ir.Constant(DOUBLE, node.threshold)
421+
DTYPE = get_fdtype(use_fp64)
422+
thresh = ir.Constant(DTYPE, node.threshold)
399423
missing_t = node.decision_type.missing_type
400424

401425
# If missingType != MNaN, LightGBM treats NaNs values as if they were 0.0.
@@ -417,7 +441,9 @@ def _populate_numerical_node_block(func, builder, node):
417441
# unordered cmp: we'll get True (and go left) if any arg is qNaN
418442
comp = builder.fcmp_unordered("<=", val, thresh)
419443
else:
420-
is_missing = builder.fcmp_unordered("==", val, fconst(0.0))
444+
is_missing = builder.fcmp_unordered(
445+
"==", val, get_fdtype_const(0.0, use_fp64)
446+
)
421447
less_eq = builder.fcmp_unordered("<=", val, thresh)
422448
comp = builder.or_(is_missing, less_eq)
423449
else:
@@ -427,7 +453,9 @@ def _populate_numerical_node_block(func, builder, node):
427453
# ordered cmp: we'll get False (and go right) if any arg is qNaN
428454
comp = builder.fcmp_ordered("<=", val, thresh)
429455
else:
430-
is_missing = builder.fcmp_unordered("==", val, fconst(0.0))
456+
is_missing = builder.fcmp_unordered(
457+
"==", val, get_fdtype_const(0.0, use_fp64)
458+
)
431459
greater = builder.fcmp_ordered(">", val, thresh)
432460
comp = builder.not_(builder.or_(is_missing, greater))
433461
return comp

lleaves/compiler/tree_compiler.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@ def compile_to_module(
1313
finline=True,
1414
raw_score=False,
1515
froot_func_name="forest_root",
16+
use_fp64=True,
1617
):
1718
forest = parse_to_ast(file_path)
1819
forest.raw_score = raw_score
1920

2021
ir = llvmlite.ir.Module(name="forest")
21-
gen_forest(forest, ir, fblocksize, froot_func_name)
22+
gen_forest(forest, ir, fblocksize, froot_func_name, use_fp64)
2223

2324
ir.triple = llvm.get_process_triple()
2425
module = llvm.parse_assembly(str(ir))

lleaves/data_processing.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
import os
3-
from ctypes import POINTER, c_double
3+
from ctypes import POINTER, c_double, c_float
44
from typing import List, Optional
55

66
import numpy as np
@@ -94,16 +94,22 @@ def data_to_ndarray(data, pd_traintime_categories: Optional[List[List]] = None):
9494
return data
9595

9696

97-
def ndarray_to_ptr(data: np.ndarray):
97+
def ndarray_to_ptr(data: np.ndarray, use_fp64: bool = True):
9898
"""
99-
Takes a 2D numpy array, converts to float64 if necessary and returns a pointer
99+
Takes a 2D numpy array, converts it to either float64 or float32 depending on the `use_fp64` flag,
100+
and returns a pointer to the data.
100101
101102
:param data: 2D numpy array. Copying is avoided if possible.
102-
:return: pointer to 1D array of dtype float64.
103+
:param use_fp64: Bool. Casting to float64 if True, otherwise float32.
104+
:return: pointer to 1D array of type float64 if `use_fp64` is True, otherwise float32.
103105
"""
104106
# ravel makes sure we get a contiguous array in memory and not some strided View
105-
data = data.astype(np.float64, copy=False, casting="same_kind").ravel()
106-
ptr = data.ctypes.data_as(POINTER(c_double))
107+
data = data.astype(
108+
np.float64 if use_fp64 else np.float32,
109+
copy=False,
110+
casting="same_kind",
111+
).ravel()
112+
ptr = data.ctypes.data_as(POINTER(c_double if use_fp64 else c_float))
107113
return ptr
108114

109115

0 commit comments

Comments
 (0)