Skip to content

Commit

Permalink
fix hpo param convert bug: #155 (#198)
Browse files Browse the repository at this point in the history
Co-authored-by: 杨熙 <mengli.cml@alibaba-inc.com>
  • Loading branch information
chengmengli06 and 杨熙 authored May 25, 2022
1 parent 8219ded commit dc452db
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 17 deletions.
17 changes: 17 additions & 0 deletions easy_rec/python/test/hpo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from easy_rec.python.utils import config_util
from easy_rec.python.utils import hpo_util
from easy_rec.python.utils import test_utils
from easy_rec.python.protos.feature_config_pb2 import FeatureConfig

if tf.__version__ >= '2.0':
gfile = tf.compat.v1.gfile
Expand Down Expand Up @@ -198,6 +199,22 @@ def test_edit_config_v12(self):
assert len(tmp_fea.boundaries) == 25
assert np.abs(tmp_fea.boundaries[1] - 21.0) < 1e-5

def test_edit_config_v13(self):
tmp_file = 'samples/model_config/deepfm_multi_cls_on_avazu_ctr.config'
tmp_config = config_util.get_configs_from_pipeline_file(tmp_file)
tmp_file = 'samples/hpo/hpo_param_v13.json'
tmp_config = config_util.edit_config(tmp_config, self.load_config(tmp_file))
assert not tmp_config.export_config.multi_placeholder

def test_edit_config_v14(self):
tmp_file = 'samples/model_config/deepfm_multi_cls_on_avazu_ctr.config'
tmp_config = config_util.get_configs_from_pipeline_file(tmp_file)
tmp_file = 'samples/hpo/hpo_param_v14.json'
tmp_config = config_util.edit_config(tmp_config, self.load_config(tmp_file))
for i, tmp_fea in enumerate(tmp_config.feature_configs):
if tmp_fea.input_names[0] == 'hour':
assert len(tmp_fea.feature_type) == FeatureConfig.RawFeature

def test_save_eval_metrics_with_env(self):
os.environ['TF_CONFIG'] = """
{ "cluster": {
Expand Down
41 changes: 24 additions & 17 deletions easy_rec/python/utils/config_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,22 @@ def edit_config(pipeline_config, edit_config_json):
edit_config_json: edit config json
"""

def _type_convert(proto, val, parent=None):
if type(val) != type(proto):
try:
if isinstance(proto, bool):
assert val in ['True', 'true', 'False', 'false']
val = val in ['True', 'true']
else:
val = type(proto)(val)
except ValueError as ex:
if parent is None:
raise ex
assert isinstance(proto, int)
val = getattr(parent, val)
assert isinstance(val, int)
return val

def _get_attr(obj, attr, only_last=False):
# only_last means we only return the last element in paths array
attr_toks = [x.strip() for x in attr.split('.') if x != '']
Expand Down Expand Up @@ -238,14 +254,9 @@ def _get_attr(obj, attr, only_last=False):
for tid, update_obj in enumerate(update_objs):
tmp, tmp_parent, _, _ = _get_attr(
update_obj, cond_key, only_last=True)
if type(cond_val) != type(tmp):
try:
cond_val = type(tmp)(cond_val)
except ValueError:
# to support for enumerations like IdFeature
assert isinstance(tmp, int)
cond_val = getattr(tmp_parent, cond_val)
assert isinstance(cond_val, int)

cond_val = _type_convert(tmp, cond_val, tmp_parent)

if op_func(tmp, cond_val):
obj_id = tid
paths.append((update_obj, update_objs, None, obj_id))
Expand Down Expand Up @@ -275,15 +286,11 @@ def _get_attr(obj, attr, only_last=False):
basic_types = [int, str, float, bool, type(u'')]
if type(tmp_val) in basic_types:
# simple type cast
try:
tmp_val = type(tmp_val)(param_val)
if tmp_name is None:
tmp_obj[tmp_id] = tmp_val
else:
setattr(tmp_obj, tmp_name, tmp_val)
except ValueError:
# for enumeration types
text_format.Merge('%s:%s' % (tmp_name, param_val), tmp_obj)
tmp_val = _type_convert(tmp_val, param_val, tmp_obj)
if tmp_name is None:
tmp_obj[tmp_id] = tmp_val
else:
setattr(tmp_obj, tmp_name, tmp_val)
elif 'Scalar' in str(type(tmp_val)) and 'ClearField' in dir(tmp_obj):
tmp_obj.ClearField(tmp_name)
text_format.Parse('%s:%s' % (tmp_name, param_val), tmp_obj)
Expand Down
5 changes: 5 additions & 0 deletions samples/hpo/hpo_param_v13.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"param": {
"export_config.multi_placeholder": "false"
}
}
5 changes: 5 additions & 0 deletions samples/hpo/hpo_param_v14.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"param": {
"feature_config.features[input_names[0]=hour].feature_type": "RawFeature"
}
}

0 comments on commit dc452db

Please sign in to comment.