Skip to content

Commit

Permalink
[feature] allow finer control over embedding variable (#197)
Browse files Browse the repository at this point in the history
Co-authored-by: 杨熙 <mengli.cml@alibaba-inc.com>
  • Loading branch information
chengmengli06 and 杨熙 authored May 26, 2022
1 parent dc452db commit 9431295
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 30 deletions.
57 changes: 28 additions & 29 deletions easy_rec/python/feature_column/feature_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from easy_rec.python.compat.feature_column import sequence_feature_column
from easy_rec.python.protos.feature_config_pb2 import FeatureConfig
from easy_rec.python.protos.feature_config_pb2 import WideOrDeep
from easy_rec.python.utils.proto_util import copy_obj

from easy_rec.python.compat.feature_column import feature_column_v2 as feature_column # NOQA

Expand Down Expand Up @@ -66,28 +67,24 @@ def __init__(self,
self._use_embedding_variable = use_embedding_variable
self._vocab_size = {}

def _cmp_embed_config(a, b):
  """Return True if two FeatureConfig protos have matching embedding settings.

  Used to verify that features sharing one embedding table (same
  embedding_name) agree on every field that affects the table's shape
  and lookup behavior: embedding_dim, combiner, initializer,
  max_partitions and use_embedding_variable.
  """
  return a.embedding_dim == b.embedding_dim and a.combiner == b.combiner and\
      a.initializer == b.initializer and a.max_partitions == b.max_partitions and\
      a.use_embedding_variable == b.use_embedding_variable

for config in self._feature_configs:
if not config.HasField('embedding_name'):
continue
embed_name = config.embedding_name
embed_info = {
'embedding_dim':
config.embedding_dim,
'combiner':
config.combiner,
'initializer':
config.initializer if config.HasField('initializer') else None,
'max_partitions':
config.max_partitions
}

if embed_name in self._share_embed_names:
assert embed_info == self._share_embed_infos[embed_name], \
assert _cmp_embed_config(config, self._share_embed_infos[embed_name]),\
'shared embed info of [%s] is not matched [%s] vs [%s]' % (
embed_name, embed_info, self._share_embed_infos[embed_name])
embed_name, config, self._share_embed_infos[embed_name])
self._share_embed_names[embed_name] += 1
else:
self._share_embed_names[embed_name] = 1
self._share_embed_infos[embed_name] = embed_info
self._share_embed_infos[embed_name] = copy_obj(config)

# remove not shared embedding names
not_shared = [
Expand Down Expand Up @@ -133,20 +130,22 @@ def __init__(self,

for embed_name in self._share_embed_names:
initializer = None
if self._share_embed_infos[embed_name]['initializer']:
if self._share_embed_infos[embed_name].HasField('initializer'):
initializer = hyperparams_builder.build_initializer(
self._share_embed_infos[embed_name]['initializer'])
self._share_embed_infos[embed_name].initializer)
partitioner = self._build_partitioner(
self._share_embed_infos[embed_name]['max_partitions'])
self._share_embed_infos[embed_name])
use_ev = self._use_embedding_variable or \
self._share_embed_infos[embed_name].use_embedding_variable
# for handling share embedding columns
share_embed_fcs = feature_column.shared_embedding_columns(
self._deep_share_embed_columns[embed_name],
self._share_embed_infos[embed_name]['embedding_dim'],
self._share_embed_infos[embed_name].embedding_dim,
initializer=initializer,
shared_embedding_collection_name=embed_name,
combiner=self._share_embed_infos[embed_name]['combiner'],
combiner=self._share_embed_infos[embed_name].combiner,
partitioner=partitioner,
use_embedding_variable=self._use_embedding_variable)
use_embedding_variable=use_ev)
self._deep_share_embed_columns[embed_name] = share_embed_fcs
# for handling wide share embedding columns
if len(self._wide_share_embed_columns[embed_name]) == 0:
Expand All @@ -158,7 +157,7 @@ def __init__(self,
shared_embedding_collection_name=embed_name + '_wide',
combiner='sum',
partitioner=partitioner,
use_embedding_variable=self._use_embedding_variable)
use_embedding_variable=use_ev)
self._wide_share_embed_columns[embed_name] = share_embed_fcs

for fc_name in self._deep_columns:
Expand Down Expand Up @@ -475,13 +474,13 @@ def parse_sequence_feature(self, config):
else:
self._sequence_columns[feature_name] = fc

def _build_partitioner(self, config):
  """Build a variable partitioner for an embedding table.

  Args:
    config: FeatureConfig proto; reads `max_partitions` and
      `use_embedding_variable`.

  Returns:
    A `tf.fixed_size_partitioner` when embedding variables are enabled
    (globally via self._use_embedding_variable or per-feature via
    config.use_embedding_variable), a `min_max_variable_partitioner`
    otherwise, or None when config.max_partitions <= 1 (no partitioning).
  """
  if config.max_partitions <= 1:
    return None
  if self._use_embedding_variable or config.use_embedding_variable:
    # pai embedding_variable should use fixed_size_partitioner
    return tf.fixed_size_partitioner(num_shards=config.max_partitions)
  return min_max_variable_partitioner(max_partitions=config.max_partitions)

Expand Down Expand Up @@ -523,8 +522,8 @@ def _add_wide_embedding_column(self, fc, config):
self._wide_output_dim,
combiner='sum',
initializer=initializer,
partitioner=self._build_partitioner(config.max_partitions),
use_embedding_variable=self._use_embedding_variable)
partitioner=self._build_partitioner(config),
use_embedding_variable=self._use_embedding_variable or config.use_embedding_variable)
self._wide_columns[feature_name] = wide_fc

def _add_deep_embedding_column(self, fc, config):
Expand All @@ -543,8 +542,8 @@ def _add_deep_embedding_column(self, fc, config):
config.embedding_dim,
combiner=config.combiner,
initializer=initializer,
partitioner=self._build_partitioner(config.max_partitions),
use_embedding_variable=self._use_embedding_variable)
partitioner=self._build_partitioner(config),
use_embedding_variable=self._use_embedding_variable or config.use_embedding_variable)
if config.feature_type != config.SequenceFeature:
self._deep_columns[feature_name] = fc
else:
Expand Down
3 changes: 3 additions & 0 deletions easy_rec/python/protos/feature_config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ message FeatureConfig {

// for expr feature
optional string expression = 30;

// use embedding variables
optional bool use_embedding_variable = 31 [default=false];
}

message FeatureConfigV2 {
Expand Down
5 changes: 4 additions & 1 deletion easy_rec/python/utils/fg_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ def load_fg_json_to_config(pipeline_config):

pipeline_config.data_config.ClearField('input_fields')
pipeline_config.ClearField('feature_configs')
pipeline_config.feature_config.ClearField('features')

# do not clear existing features, so that extra features which are
# not defined in fg.json can still be specified in the config
# pipeline_config.feature_config.ClearField('features')

for input_config in fg_config.data_config.input_fields:
in_config = DatasetConfig.Field()
Expand Down

0 comments on commit 9431295

Please sign in to comment.