Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python-package] use dataclass for CallbackEnv #6048

Merged
merged 4 commits into from
Aug 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .ci/test-python-oldest.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#
echo "installing lightgbm's dependencies"
pip install \
'dataclasses' \
'numpy==1.12.0' \
'pandas==0.24.0' \
'scikit-learn==0.18.2' \
Expand Down
28 changes: 16 additions & 12 deletions python-package/lightgbm/callback.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
# coding: utf-8
"""Callbacks library."""
import collections
from collections import OrderedDict
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

from .basic import _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning
from .basic import Booster, _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning

if TYPE_CHECKING:
from .engine import CVBooster

__all__ = [
'early_stopping',
Expand Down Expand Up @@ -43,14 +47,14 @@ def __init__(self, best_iteration: int, best_score: _ListOfEvalResultTuples) ->


# Callback environment used by callbacks
CallbackEnv = collections.namedtuple(
"CallbackEnv",
["model",
"params",
"iteration",
"begin_iteration",
"end_iteration",
"evaluation_result_list"])
@dataclass
class CallbackEnv:
    """Per-iteration training state handed to every callback.

    Drop-in replacement for the previous ``collections.namedtuple``-based
    ``CallbackEnv``: the field names and their order are unchanged, so
    existing positional construction and attribute access keep working.
    """

    # The model being trained: a single Booster, or a CVBooster when
    # training was launched through cross-validation.
    model: Union[Booster, "CVBooster"]
    # Parameters the training run was configured with.
    params: Dict[str, Any]
    # Zero-based index of the current boosting iteration.
    iteration: int
    # First iteration of this training run — presumably nonzero when
    # continuing from an existing model (TODO confirm against callers).
    begin_iteration: int
    # Iteration index at which training will stop.
    end_iteration: int
    # Evaluation results for this iteration; None when no evaluation
    # sets were configured.
    evaluation_result_list: Optional[List[_LGBM_BoosterEvalMethodResultType]]


def _format_eval_result(value: _EvalResultTuple, show_stdv: bool) -> str:
Expand Down Expand Up @@ -126,7 +130,7 @@ def _init(self, env: CallbackEnv) -> None:
data_name, eval_name = item[:2]
else: # cv
data_name, eval_name = item[1].split()
self.eval_result.setdefault(data_name, collections.OrderedDict())
self.eval_result.setdefault(data_name, OrderedDict())
if len(item) == 4:
self.eval_result[data_name].setdefault(eval_name, [])
else:
Expand Down
8 changes: 4 additions & 4 deletions python-package/lightgbm/engine.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# coding: utf-8
"""Library with training routines of LightGBM."""
import collections
import copy
import json
from collections import OrderedDict, defaultdict
from operator import attrgetter
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
Expand Down Expand Up @@ -293,7 +293,7 @@ def train(
booster.best_iteration = earlyStopException.best_iteration + 1
evaluation_result_list = earlyStopException.best_score
break
booster.best_score = collections.defaultdict(collections.OrderedDict)
booster.best_score = defaultdict(OrderedDict)
for dataset_name, eval_name, score, _ in evaluation_result_list:
booster.best_score[dataset_name][eval_name] = score
if not keep_training_booster:
Expand Down Expand Up @@ -526,7 +526,7 @@ def _agg_cv_result(
raw_results: List[List[Tuple[str, str, float, bool]]]
) -> List[Tuple[str, str, float, bool, float]]:
"""Aggregate cross-validation results."""
cvmap: Dict[str, List[float]] = collections.OrderedDict()
cvmap: Dict[str, List[float]] = OrderedDict()
metric_type: Dict[str, bool] = {}
for one_result in raw_results:
for one_line in one_result:
Expand Down Expand Up @@ -717,7 +717,7 @@ def cv(
.set_feature_name(feature_name) \
.set_categorical_feature(categorical_feature)

results = collections.defaultdict(list)
results = defaultdict(list)
cvfolds = _make_n_folds(full_data=train_set, folds=folds, nfold=nfold,
params=params, seed=seed, fpreproc=fpreproc,
stratified=stratified, shuffle=shuffle,
Expand Down
1 change: 1 addition & 0 deletions python-package/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Artificial Intelligence"
]
dependencies = [
"dataclasses ; python_version < '3.7'",
"numpy",
"scipy"
]
Expand Down