Skip to content

Commit 7b38bac

Browse files
authored
Merge pull request #83 from siboehm/siboehm/safeSoftmax
safe softmax
2 parents a82ec55 + e50798e commit 7b38bac

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

lleaves/compiler/codegen/codegen.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -367,14 +367,20 @@ def _populate_sigmoid(alpha):
367367
result = args[0]
368368
elif objective == "multiclass":
369369
assert len(args)
370-
# TODO Might profit from vectorization, needs testing
371-
result = [builder.call(llvm_exp, [arg]) for arg in args]
372-
370+
# stable softmax
371+
max_val = args[0]
372+
for arg in args[1:]:
373+
max_val = builder.select(
374+
builder.fcmp_ordered(">", arg, max_val), arg, max_val
375+
)
376+
exp_vals = [
377+
builder.call(llvm_exp, [builder.fsub(arg, max_val)]) for arg in args
378+
]
373379
denominator = get_fdtype_const(0.0, use_fp64)
374-
for r in result:
375-
denominator = builder.fadd(r, denominator)
376-
377-
result = [builder.fdiv(r, denominator) for r in result]
380+
for exp_val in exp_vals:
381+
denominator = builder.fadd(exp_val, denominator)
382+
denominator = builder.fadd(denominator, get_fdtype_const(1e-15, use_fp64))
383+
result = [builder.fdiv(exp_val, denominator) for exp_val in exp_vals]
378384
else:
379385
raise ValueError(
380386
f"Objective '{objective}' not yet implemented. {ISSUE_ERROR_MSG}"

0 commit comments

Comments
 (0)