diff --git a/easy_rec/python/feature_column/feature_column.py b/easy_rec/python/feature_column/feature_column.py index b6720a815..438df823e 100644 --- a/easy_rec/python/feature_column/feature_column.py +++ b/easy_rec/python/feature_column/feature_column.py @@ -9,6 +9,7 @@ from easy_rec.python.compat.feature_column import sequence_feature_column from easy_rec.python.protos.feature_config_pb2 import FeatureConfig from easy_rec.python.protos.feature_config_pb2 import WideOrDeep +from easy_rec.python.utils.proto_util import copy_obj from easy_rec.python.compat.feature_column import feature_column_v2 as feature_column # NOQA @@ -66,28 +67,24 @@ def __init__(self, self._use_embedding_variable = use_embedding_variable self._vocab_size = {} + def _cmp_embed_config(a, b): + return a.embedding_dim == b.embedding_dim and a.combiner == b.combiner and\ + a.initializer == b.initializer and a.max_partitions == b.max_partitions and\ + a.use_embedding_variable == b.use_embedding_variable + for config in self._feature_configs: if not config.HasField('embedding_name'): continue embed_name = config.embedding_name - embed_info = { - 'embedding_dim': - config.embedding_dim, - 'combiner': - config.combiner, - 'initializer': - config.initializer if config.HasField('initializer') else None, - 'max_partitions': - config.max_partitions - } + if embed_name in self._share_embed_names: - assert embed_info == self._share_embed_infos[embed_name], \ + assert _cmp_embed_config(config, self._share_embed_infos[embed_name]),\ 'shared embed info of [%s] is not matched [%s] vs [%s]' % ( - embed_name, embed_info, self._share_embed_infos[embed_name]) + embed_name, config, self._share_embed_infos[embed_name]) self._share_embed_names[embed_name] += 1 else: self._share_embed_names[embed_name] = 1 - self._share_embed_infos[embed_name] = embed_info + self._share_embed_infos[embed_name] = copy_obj(config) # remove not shared embedding names not_shared = [ @@ -133,20 +130,22 @@ def __init__(self, for embed_name 
in self._share_embed_names: initializer = None - if self._share_embed_infos[embed_name]['initializer']: + if self._share_embed_infos[embed_name].HasField('initializer'): initializer = hyperparams_builder.build_initializer( - self._share_embed_infos[embed_name]['initializer']) + self._share_embed_infos[embed_name].initializer) partitioner = self._build_partitioner( - self._share_embed_infos[embed_name]['max_partitions']) + self._share_embed_infos[embed_name]) + use_ev = self._use_embedding_variable or \ + self._share_embed_infos[embed_name].use_embedding_variable # for handling share embedding columns share_embed_fcs = feature_column.shared_embedding_columns( self._deep_share_embed_columns[embed_name], - self._share_embed_infos[embed_name]['embedding_dim'], + self._share_embed_infos[embed_name].embedding_dim, initializer=initializer, shared_embedding_collection_name=embed_name, - combiner=self._share_embed_infos[embed_name]['combiner'], + combiner=self._share_embed_infos[embed_name].combiner, partitioner=partitioner, - use_embedding_variable=self._use_embedding_variable) + use_embedding_variable=use_ev) self._deep_share_embed_columns[embed_name] = share_embed_fcs # for handling wide share embedding columns if len(self._wide_share_embed_columns[embed_name]) == 0: @@ -158,7 +157,7 @@ def __init__(self, shared_embedding_collection_name=embed_name + '_wide', combiner='sum', partitioner=partitioner, - use_embedding_variable=self._use_embedding_variable) + use_embedding_variable=use_ev) self._wide_share_embed_columns[embed_name] = share_embed_fcs for fc_name in self._deep_columns: @@ -475,13 +474,13 @@ def parse_sequence_feature(self, config): else: self._sequence_columns[feature_name] = fc - def _build_partitioner(self, max_partitions): - if max_partitions > 1: - if self._use_embedding_variable: + def _build_partitioner(self, config): + if config.max_partitions > 1: + if self._use_embedding_variable or config.use_embedding_variable: # pai embedding_variable should use 
fixed_size_partitioner - return tf.fixed_size_partitioner(num_shards=max_partitions) + return tf.fixed_size_partitioner(num_shards=config.max_partitions) else: - return min_max_variable_partitioner(max_partitions=max_partitions) + return min_max_variable_partitioner(max_partitions=config.max_partitions) else: return None @@ -523,8 +522,8 @@ def _add_wide_embedding_column(self, fc, config): self._wide_output_dim, combiner='sum', initializer=initializer, - partitioner=self._build_partitioner(config.max_partitions), - use_embedding_variable=self._use_embedding_variable) + partitioner=self._build_partitioner(config), + use_embedding_variable=self._use_embedding_variable or config.use_embedding_variable) self._wide_columns[feature_name] = wide_fc def _add_deep_embedding_column(self, fc, config): @@ -543,8 +542,8 @@ def _add_deep_embedding_column(self, fc, config): config.embedding_dim, combiner=config.combiner, initializer=initializer, - partitioner=self._build_partitioner(config.max_partitions), - use_embedding_variable=self._use_embedding_variable) + partitioner=self._build_partitioner(config), + use_embedding_variable=self._use_embedding_variable or config.use_embedding_variable) if config.feature_type != config.SequenceFeature: self._deep_columns[feature_name] = fc else: diff --git a/easy_rec/python/protos/feature_config.proto b/easy_rec/python/protos/feature_config.proto index e8fbb4d56..2120b7b79 100644 --- a/easy_rec/python/protos/feature_config.proto +++ b/easy_rec/python/protos/feature_config.proto @@ -113,6 +113,9 @@ message FeatureConfig { // for expr feature optional string expression = 30; + + // use embedding variables + optional bool use_embedding_variable = 31 [default=false]; } message FeatureConfigV2 { diff --git a/easy_rec/python/utils/fg_util.py b/easy_rec/python/utils/fg_util.py index fb5694287..040277af1 100644 --- a/easy_rec/python/utils/fg_util.py +++ b/easy_rec/python/utils/fg_util.py @@ -26,7 +26,10 @@ def 
load_fg_json_to_config(pipeline_config): pipeline_config.data_config.ClearField('input_fields') pipeline_config.ClearField('feature_configs') - pipeline_config.feature_config.ClearField('features') + + # do not clear features, so that extra features can be defined + # which are not present in fg.json + # pipeline_config.feature_config.ClearField('features') for input_config in fg_config.data_config.input_fields: in_config = DatasetConfig.Field()