diff --git a/easy_rec/python/test/hpo_test.py b/easy_rec/python/test/hpo_test.py index d3fc60a71..983081e49 100644 --- a/easy_rec/python/test/hpo_test.py +++ b/easy_rec/python/test/hpo_test.py @@ -12,6 +12,7 @@ from easy_rec.python.utils import config_util from easy_rec.python.utils import hpo_util from easy_rec.python.utils import test_utils +from easy_rec.python.protos.feature_config_pb2 import FeatureConfig if tf.__version__ >= '2.0': gfile = tf.compat.v1.gfile @@ -198,6 +199,22 @@ def test_edit_config_v12(self): assert len(tmp_fea.boundaries) == 25 assert np.abs(tmp_fea.boundaries[1] - 21.0) < 1e-5 + def test_edit_config_v13(self): + tmp_file = 'samples/model_config/deepfm_multi_cls_on_avazu_ctr.config' + tmp_config = config_util.get_configs_from_pipeline_file(tmp_file) + tmp_file = 'samples/hpo/hpo_param_v13.json' + tmp_config = config_util.edit_config(tmp_config, self.load_config(tmp_file)) + assert not tmp_config.export_config.multi_placeholder + + def test_edit_config_v14(self): + tmp_file = 'samples/model_config/deepfm_multi_cls_on_avazu_ctr.config' + tmp_config = config_util.get_configs_from_pipeline_file(tmp_file) + tmp_file = 'samples/hpo/hpo_param_v14.json' + tmp_config = config_util.edit_config(tmp_config, self.load_config(tmp_file)) + for i, tmp_fea in enumerate(tmp_config.feature_configs): + if tmp_fea.input_names[0] == 'hour': + assert len(tmp_fea.feature_type) == FeatureConfig.RawFeature + def test_save_eval_metrics_with_env(self): os.environ['TF_CONFIG'] = """ { "cluster": { diff --git a/easy_rec/python/utils/config_util.py b/easy_rec/python/utils/config_util.py index 65a0df56f..a2e73b8c0 100644 --- a/easy_rec/python/utils/config_util.py +++ b/easy_rec/python/utils/config_util.py @@ -163,6 +163,22 @@ def edit_config(pipeline_config, edit_config_json): edit_config_json: edit config json """ + def _type_convert(proto, val, parent=None): + if type(val) != type(proto): + try: + if isinstance(proto, bool): + assert val in ['True', 'true', 'False', 'false'] + val = val in ['True', 'true'] + else: + val = type(proto)(val) + except ValueError as ex: + if parent is None: + raise ex + assert isinstance(proto, int) + val = getattr(parent, val) + assert isinstance(val, int) + return val + def _get_attr(obj, attr, only_last=False): # only_last means we only return the last element in paths array attr_toks = [x.strip() for x in attr.split('.') if x != ''] @@ -238,14 +254,9 @@ def _get_attr(obj, attr, only_last=False): for tid, update_obj in enumerate(update_objs): tmp, tmp_parent, _, _ = _get_attr( update_obj, cond_key, only_last=True) - if type(cond_val) != type(tmp): - try: - cond_val = type(tmp)(cond_val) - except ValueError: - # to support for enumerations like IdFeature - assert isinstance(tmp, int) - cond_val = getattr(tmp_parent, cond_val) - assert isinstance(cond_val, int) + + cond_val = _type_convert(tmp, cond_val, tmp_parent) + if op_func(tmp, cond_val): obj_id = tid paths.append((update_obj, update_objs, None, obj_id)) @@ -275,15 +286,11 @@ def _get_attr(obj, attr, only_last=False): basic_types = [int, str, float, bool, type(u'')] if type(tmp_val) in basic_types: # simple type cast - try: - tmp_val = type(tmp_val)(param_val) - if tmp_name is None: - tmp_obj[tmp_id] = tmp_val - else: - setattr(tmp_obj, tmp_name, tmp_val) - except ValueError: - # for enumeration types - text_format.Merge('%s:%s' % (tmp_name, param_val), tmp_obj) + tmp_val = _type_convert(tmp_val, param_val, tmp_obj) + if tmp_name is None: + tmp_obj[tmp_id] = tmp_val + else: + setattr(tmp_obj, tmp_name, tmp_val) elif 'Scalar' in str(type(tmp_val)) and 'ClearField' in dir(tmp_obj): tmp_obj.ClearField(tmp_name) text_format.Parse('%s:%s' % (tmp_name, param_val), tmp_obj) diff --git a/samples/hpo/hpo_param_v13.json b/samples/hpo/hpo_param_v13.json new file mode 100755 index 000000000..9acf7acec --- /dev/null +++ b/samples/hpo/hpo_param_v13.json @@ -0,0 +1,5 @@ +{ + "param": { + "export_config.multi_placeholder": "false" + } +} diff --git a/samples/hpo/hpo_param_v14.json b/samples/hpo/hpo_param_v14.json new file mode 100755 index 000000000..b4af1c916 --- /dev/null +++ b/samples/hpo/hpo_param_v14.json @@ -0,0 +1,5 @@ +{ + "param": { + "feature_config.features[input_names[0]=hour].feature_type": "RawFeature" + } +}