11
11
INT = ir .IntType (bits = 32 )
12
12
LONG = ir .IntType (bits = 64 )
13
13
ZERO_V = ir .Constant (BOOL , 0 )
14
- FLOAT_POINTER = ir .PointerType (FLOAT )
14
+ FLOAT_PTR = ir .PointerType (FLOAT )
15
15
DOUBLE_PTR = ir .PointerType (DOUBLE )
16
16
17
17
@@ -33,6 +33,14 @@ def dconst(value):
33
33
return ir .Constant (DOUBLE , value )
34
34
35
35
36
def get_fdtype_const(value, use_fp64):
    """Build an IR constant for ``value`` at the selected float precision.

    Delegates to ``dconst`` when ``use_fp64`` is truthy and to ``fconst``
    otherwise; the chosen helper's result is returned unchanged.
    """
    make_const = dconst if use_fp64 else fconst
    return make_const(value)
38
+
39
+
40
def get_fdtype(use_fp64):
    """Return the module-level float IR type for the requested precision.

    ``DOUBLE`` is returned when ``use_fp64`` is truthy, ``FLOAT`` otherwise.
    """
    if use_fp64:
        return DOUBLE
    return FLOAT
42
+
43
+
36
44
@dataclass
37
45
class LTree :
38
46
"""Class for the LLVM function of a tree paired with relevant non-LLVM context"""
@@ -41,7 +49,7 @@ class LTree:
41
49
class_id : int
42
50
43
51
44
- def gen_forest (forest , module , fblocksize , froot_func_name ):
52
+ def gen_forest (forest , module , fblocksize , froot_func_name , use_fp64 ):
45
53
"""
46
54
Populate the passed IR module with code for the forest.
47
55
@@ -80,20 +88,23 @@ def gen_forest(forest, module, fblocksize, froot_func_name):
80
88
"""
81
89
82
90
# entry function called from Python
91
+ DTYPE_PTR = DOUBLE_PTR if use_fp64 else FLOAT_PTR
83
92
root_func = ir .Function (
84
93
module ,
85
- ir .FunctionType (ir .VoidType (), (DOUBLE_PTR , DOUBLE_PTR , INT , INT )),
94
+ ir .FunctionType (ir .VoidType (), (DTYPE_PTR , DTYPE_PTR , INT , INT )),
86
95
name = froot_func_name ,
87
96
)
88
97
89
98
def make_tree (tree ):
90
99
# declare the function for this tree
91
- func_dtypes = (INT_CAT if f .is_categorical else DOUBLE for f in tree .features )
92
- scalar_func_t = ir .FunctionType (DOUBLE , func_dtypes )
100
+ func_dtypes = (
101
+ INT_CAT if f .is_categorical else get_fdtype (use_fp64 ) for f in tree .features
102
+ )
103
+ scalar_func_t = ir .FunctionType (get_fdtype (use_fp64 ), func_dtypes )
93
104
tree_func = ir .Function (module , scalar_func_t , name = str (tree ))
94
105
tree_func .linkage = "private"
95
106
# populate function with IR
96
- gen_tree (tree , tree_func )
107
+ gen_tree (tree , tree_func , use_fp64 )
97
108
return LTree (llvm_function = tree_func , class_id = tree .class_id )
98
109
99
110
tree_funcs = [make_tree (tree ) for tree in forest .trees ]
@@ -102,30 +113,30 @@ def make_tree(tree):
102
113
# better locality by running trees for each class together
103
114
tree_funcs .sort (key = lambda t : t .class_id )
104
115
105
- _populate_forest_func (forest , root_func , tree_funcs , fblocksize )
116
+ _populate_forest_func (forest , root_func , tree_funcs , fblocksize , use_fp64 )
106
117
107
118
108
def gen_tree(tree, tree_func, use_fp64):
    """Emit the IR for ``tree`` into ``tree_func``, starting at its root node.

    A basic block named after the root node is appended to ``tree_func`` and
    handed to ``gen_node``, which recurses into the rest of the tree.
    ``use_fp64`` is forwarded so that nested generators pick the matching
    float precision.
    """
    root = tree.root_node
    entry_block = tree_func.append_basic_block(name=str(root))
    gen_node(tree_func, entry_block, root, use_fp64)
112
123
113
124
114
def gen_node(func, node_block, node, use_fp64):
    """Emit the IR for ``node`` into ``node_block``.

    Decision nodes are handled by ``_gen_decision_node`` (which also receives
    the enclosing ``func``); leaves are handled by ``_gen_leaf_node``.
    ``use_fp64`` is passed through to whichever generator is chosen.
    """
    if not node.is_leaf:
        _gen_decision_node(func, node_block, node, use_fp64)
        return
    _gen_leaf_node(node_block, node, use_fp64)
120
131
121
132
122
def _gen_leaf_node(node_block, leaf, use_fp64):
    """Terminate ``node_block`` with a ``ret`` of the leaf's value.

    The returned constant is built from ``leaf.value`` via
    ``get_fdtype_const`` so its precision follows ``use_fp64``.
    """
    emitter = ir.IRBuilder(node_block)
    leaf_const = get_fdtype_const(leaf.value, use_fp64)
    emitter.ret(leaf_const)
126
137
127
138
128
- def _gen_decision_node (func , node_block , node ):
139
+ def _gen_decision_node (func , node_block , node , use_fp64 ):
129
140
"""generate code for decision node, recursing into children"""
130
141
builder = ir .IRBuilder (node_block )
131
142
@@ -151,20 +162,24 @@ def _gen_decision_node(func, node_block, node):
151
162
)
152
163
builder = bitset_builder
153
164
else :
154
- comp = _populate_numerical_node_block (func , builder , node )
165
+ comp = _populate_numerical_node_block (func , builder , node , use_fp64 )
155
166
156
167
# finalize this node's block with a terminal statement
157
168
if is_fused_double_leaf_node :
158
- ret = builder .select (comp , dconst (node .left .value ), dconst (node .right .value ))
169
+ ret = builder .select (
170
+ comp ,
171
+ get_fdtype_const (node .left .value , use_fp64 ),
172
+ get_fdtype_const (node .right .value , use_fp64 ),
173
+ )
159
174
builder .ret (ret )
160
175
else :
161
176
builder .cbranch (comp , left_block , right_block )
162
177
163
178
# populate generated child blocks
164
179
if left_block :
165
- gen_node (func , left_block , node .left )
180
+ gen_node (func , left_block , node .left , use_fp64 )
166
181
if right_block :
167
- gen_node (func , right_block , node .right )
182
+ gen_node (func , right_block , node .right , use_fp64 )
168
183
169
184
170
185
def _populate_instruction_block (
@@ -175,6 +190,7 @@ def _populate_instruction_block(
175
190
setup_block ,
176
191
next_block ,
177
192
eval_obj_func ,
193
+ use_fp64 ,
178
194
):
179
195
"""Generates an instruction_block: loops over all input data and evaluates its chunk of tree_funcs."""
180
196
data_arr , out_arr , start_index , end_index = root_func .args
@@ -211,14 +227,14 @@ def _populate_instruction_block(
211
227
el = builder .load (ptr )
212
228
if feature .is_categorical :
213
229
# first, check if the value is NaN
214
- is_nan = builder .fcmp_ordered ("uno" , el , dconst (0.0 ))
230
+ is_nan = builder .fcmp_ordered ("uno" , el , get_fdtype_const (0.0 , use_fp64 ))
215
231
# if it is, return smallest possible int (will always go right), else cast to int
216
232
el = builder .select (is_nan , iconst (- (2 ** 31 )), builder .fptosi (el , INT_CAT ))
217
233
args .append (el )
218
234
else :
219
235
args .append (el )
220
236
# iterate over each tree, sum up results
221
- results = [dconst (0.0 ) for _ in range (forest .n_classes )]
237
+ results = [get_fdtype_const (0.0 , use_fp64 ) for _ in range (forest .n_classes )]
222
238
for func in tree_funcs :
223
239
tree_res = builder .call (func .llvm_function , args )
224
240
results [func .class_id ] = builder .fadd (tree_res , results [func .class_id ])
@@ -243,6 +259,7 @@ def _populate_instruction_block(
243
259
forest .raw_score ,
244
260
forest .average_output ,
245
261
len (forest .trees ),
262
+ use_fp64 ,
246
263
)
247
264
for result , result_ptr in zip (results , results_ptr ):
248
265
builder .store (result , result_ptr )
@@ -252,7 +269,7 @@ def _populate_instruction_block(
252
269
# -- END CORE LOOP BLOCK
253
270
254
271
255
- def _populate_forest_func (forest , root_func , tree_funcs , fblocksize ):
272
+ def _populate_forest_func (forest , root_func , tree_funcs , fblocksize , use_fp64 ):
256
273
"""Populate root function IR for forest"""
257
274
258
275
assert fblocksize > 0
@@ -277,6 +294,7 @@ def _populate_forest_func(forest, root_func, tree_funcs, fblocksize):
277
294
setup_block ,
278
295
next_block ,
279
296
eval_objective_func ,
297
+ use_fp64 ,
280
298
)
281
299
282
300
@@ -288,28 +306,30 @@ def _populate_objective_func_block(
288
306
raw_score : bool ,
289
307
average_output : bool ,
290
308
num_trees : int ,
309
+ use_fp64 : bool ,
291
310
):
292
311
"""
293
312
Takes the objective function specification and generates the code for it into the builder
294
313
"""
295
- llvm_exp = builder .module .declare_intrinsic ("llvm.exp" , (DOUBLE ,))
296
- llvm_log = builder .module .declare_intrinsic ("llvm.log" , (DOUBLE ,))
314
+ DTYPE = get_fdtype (use_fp64 )
315
+ llvm_exp = builder .module .declare_intrinsic ("llvm.exp" , (DTYPE ,))
316
+ llvm_log = builder .module .declare_intrinsic ("llvm.log" , (DTYPE ,))
297
317
llvm_copysign = builder .module .declare_intrinsic (
298
- "llvm.copysign" , (DOUBLE , DOUBLE ), ir .FunctionType (DOUBLE , (DOUBLE , DOUBLE ))
318
+ "llvm.copysign" , (DTYPE , DTYPE ), ir .FunctionType (DTYPE , (DTYPE , DTYPE ))
299
319
)
300
320
301
321
if average_output :
302
- args [0 ] = builder .fdiv (args [0 ], dconst (num_trees ))
322
+ args [0 ] = builder .fdiv (args [0 ], get_fdtype_const (num_trees , use_fp64 ))
303
323
304
324
def _populate_sigmoid (alpha ):
305
325
if alpha <= 0 :
306
326
raise ValueError (f"Sigmoid parameter needs to be >0, is { alpha } " )
307
327
308
328
# 1 / (1 + exp(- alpha * x))
309
- inner = builder .fmul (dconst (- alpha ), args [0 ])
329
+ inner = builder .fmul (get_fdtype_const (- alpha , use_fp64 ), args [0 ])
310
330
exp = builder .call (llvm_exp , [inner ])
311
- denom = builder .fadd (dconst (1.0 ), exp )
312
- return builder .fdiv (dconst (1.0 ), denom )
331
+ denom = builder .fadd (get_fdtype_const (1.0 , use_fp64 ), exp )
332
+ return builder .fdiv (get_fdtype_const (1.0 , use_fp64 ), denom )
313
333
314
334
# raw score means we don't need to add the objective function
315
335
if raw_score :
@@ -324,7 +344,10 @@ def _populate_sigmoid(alpha):
324
344
# naive implementation which will be numerically unstable for small x.
325
345
# should be changed to log1p
326
346
exp = builder .call (llvm_exp , [args [0 ]])
327
- result = builder .call (llvm_log , [builder .fadd (dconst (1.0 ), exp )])
347
+ result = builder .call (
348
+ llvm_log , [builder .fadd (get_fdtype_const (1.0 , use_fp64 ), exp )]
349
+ )
350
+
328
351
elif objective in ("poisson" , "gamma" , "tweedie" ):
329
352
result = builder .call (llvm_exp , [args [0 ]])
330
353
elif objective in (
@@ -347,7 +370,7 @@ def _populate_sigmoid(alpha):
347
370
# TODO Might profit from vectorization, needs testing
348
371
result = [builder .call (llvm_exp , [arg ]) for arg in args ]
349
372
350
- denominator = dconst (0.0 )
373
+ denominator = get_fdtype_const (0.0 , use_fp64 )
351
374
for r in result :
352
375
denominator = builder .fadd (r , denominator )
353
376
@@ -391,11 +414,12 @@ def _populate_categorical_node_block(
391
414
return comp
392
415
393
416
394
- def _populate_numerical_node_block (func , builder , node ):
417
+ def _populate_numerical_node_block (func , builder , node , use_fp64 ):
395
418
"""populate block with IR for numerical node"""
396
419
val = func .args [node .split_feature ]
397
420
398
- thresh = ir .Constant (DOUBLE , node .threshold )
421
+ DTYPE = get_fdtype (use_fp64 )
422
+ thresh = ir .Constant (DTYPE , node .threshold )
399
423
missing_t = node .decision_type .missing_type
400
424
401
425
# If missingType != MNaN, LightGBM treats NaNs values as if they were 0.0.
@@ -417,7 +441,9 @@ def _populate_numerical_node_block(func, builder, node):
417
441
# unordered cmp: we'll get True (and go left) if any arg is qNaN
418
442
comp = builder .fcmp_unordered ("<=" , val , thresh )
419
443
else :
420
- is_missing = builder .fcmp_unordered ("==" , val , fconst (0.0 ))
444
+ is_missing = builder .fcmp_unordered (
445
+ "==" , val , get_fdtype_const (0.0 , use_fp64 )
446
+ )
421
447
less_eq = builder .fcmp_unordered ("<=" , val , thresh )
422
448
comp = builder .or_ (is_missing , less_eq )
423
449
else :
@@ -427,7 +453,9 @@ def _populate_numerical_node_block(func, builder, node):
427
453
# ordered cmp: we'll get False (and go right) if any arg is qNaN
428
454
comp = builder .fcmp_ordered ("<=" , val , thresh )
429
455
else :
430
- is_missing = builder .fcmp_unordered ("==" , val , fconst (0.0 ))
456
+ is_missing = builder .fcmp_unordered (
457
+ "==" , val , get_fdtype_const (0.0 , use_fp64 )
458
+ )
431
459
greater = builder .fcmp_ordered (">" , val , thresh )
432
460
comp = builder .not_ (builder .or_ (is_missing , greater ))
433
461
return comp
0 commit comments