Skip to content

Commit 63a26d0

Browse files
siboehm and chenglin authored
extract_pandas_traintime_categories: return [] if pandas_categorical is null in model file (#14)
* extract_pandas_traintime_categories: return empty list if pandas_categorical is null in model file
* Test: Prediction for df with empty categoricals

Co-authored-by: chenglin <chenglin.wang@amh-group.com>
Co-authored-by: Simon Boehm <simon@siboehm.com>
2 parents 26e39d6 + 54e07cd commit 63a26d0

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

lleaves/data_processing.py

+4-1
Original file line number | Diff line number | Diff line change
@@ -145,7 +145,10 @@ def extract_pandas_traintime_categories(file_path):
145145
if not last_line.startswith(pandas_key):
146146
last_line = lines[-2].decode().strip()
147147
if last_line.startswith(pandas_key):
148-
return json.loads(last_line[len(pandas_key) :])
148+
pandas_categorical = json.loads(last_line[len(pandas_key) :])
149+
if pandas_categorical is None:
150+
pandas_categorical = []
151+
return pandas_categorical
149152
raise ValueError("Ill formatted model file!")
150153

151154

tests/test_dataprocessing.py

+14-1
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ def test_parsing_pandas(tmp_path):
2828
file.writelines(lines)
2929

3030
pandas_categorical = extract_pandas_traintime_categories(model_file)
31-
assert pandas_categorical is None
31+
assert pandas_categorical == []
3232
pandas_categorical = extract_pandas_traintime_categories(mod_model_file)
3333
assert pandas_categorical == [
3434
["a", "b", "c"],
@@ -106,3 +106,16 @@ def test_sliced_arrays():
106106
llvm_model.predict(sliced, n_jobs=4), lgbm_model.predict(sliced), decimal=13
107107
)
108108
return
109+
110+
111+
def test_pd_empty_categories():
112+
# this model has `pandas_categorical:null`
113+
llvm_model = Model(model_file="tests/models/tiniest_single_tree/model.txt")
114+
llvm_model.compile()
115+
lgbm_model = Booster(model_file="tests/models/tiniest_single_tree/model.txt")
116+
df = pd.DataFrame(
117+
{str(i): list(range(10)) for i in range(llvm_model.num_feature())}
118+
)
119+
np.testing.assert_almost_equal(
120+
llvm_model.predict(df), lgbm_model.predict(df), decimal=13
121+
)

0 commit comments

Comments
 (0)