Skip to content

Commit 58fb88a

Browse files
committed
Replaced deprecated functions in benchmark.py
1 parent a82ec55 commit 58fb88a

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

benchmarks/benchmark.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
import onnxruntime as rt
88
import pandas as pd
99
import treelite
10-
import treelite_runtime
10+
import tl2cgen
1111
from onnxconverter_common import FloatTensorType
1212

13-
from benchmarks.train_NYC_model import feature_enginering
13+
from train_NYC_model import feature_enginering
1414
from lleaves import Model
1515

1616

@@ -58,20 +58,21 @@ def _setup(self, data, n_threads):
5858
# disable thread pinning, which modifies (and never resets!) process-global pthreads state
5959
os.environ["TREELITE_BIND_THREADS"] = "0"
6060
treelite_model = treelite.Model.load(self.model_file, model_format="lightgbm")
61-
treelite_model.export_lib(
61+
tl2cgen.export_lib(
62+
model=treelite_model,
6263
toolchain="gcc",
6364
libpath="/tmp/treelite_model.so",
6465
params={"parallel_comp": 4},
6566
verbose=False,
6667
)
67-
self.model = treelite_runtime.Predictor(
68+
self.model = tl2cgen.Predictor(
6869
"/tmp/treelite_model.so",
6970
nthread=n_threads,
7071
)
7172

7273
def predict(self, data, index, batchsize, n_threads):
7374
return self.model.predict(
74-
treelite_runtime.DMatrix(data[index : index + batchsize])
75+
tl2cgen.DMatrix(data[index : index + batchsize])
7576
)
7677

7778

0 commit comments

Comments
 (0)