Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/optimize dssm_senet #521

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions docs/source/models/dssm_derivatives.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ model_config:{
user_tower {
id: "user_id"
senet {
num_squeeze_group : 2
reduction_ratio: 4
num_squeeze_group : 1
reduction_ratio: 2
excitation_acitvation: 'relu'
}
dnn {
hidden_units: [128, 32]
Expand All @@ -53,8 +54,9 @@ model_config:{
item_tower {
id: "adgroup_id"
senet {
num_squeeze_group : 2
reduction_ratio: 4
num_squeeze_group : 1
reduction_ratio: 2
excitation_acitvation: 'relu'
}
dnn {
hidden_units: [128, 32]
Expand All @@ -71,15 +73,14 @@ model_config:{
```

- senet参数配置:
- num_squeeze_group: 每个特征embedding的分组个数, 默认为2
- reduction_ratio: 维度压缩比例, 默认为4
- num_squeeze_group: 每个特征embedding的分组个数, 默认为1
- reduction_ratio: 维度压缩比例, 默认为2
- excitation_acitvation: excitation weights layer的激活函数, 默认为'relu'

### 示例Config

[dssm_senet_on_taobao.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/dssm_senet_on_taobao.config)

[dssm_senet_on_taobao_backbone.config](https://github.com/alibaba/EasyRec/tree/master/samples/model_config/dssm_senet_on_taobao_backbone.config)

### 参考论文

[Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507)
Expand Down
20 changes: 12 additions & 8 deletions easy_rec/python/layers/senet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,33 +25,36 @@ def __init__(self,
num_fields,
num_squeeze_group,
reduction_ratio,
excitation_acitvation,
l2_reg,
name='SENet'):
self.num_fields = num_fields
self.num_squeeze_group = num_squeeze_group
self.reduction_ratio = reduction_ratio
self.excitation_acitvation = excitation_acitvation
self._l2_reg = l2_reg
self._name = name

def __call__(self, inputs):
g = self.num_squeeze_group
f = self.num_fields
r = self.reduction_ratio
reduction_size = max(1, f * g * 2 // r)
g = self.num_squeeze_group

emb_size = 0
for input in inputs:
emb_size += int(input.shape[-1])

group_embs = [
tf.reshape(emb, [-1, g, int(emb.shape[-1]) // g]) for emb in inputs
]
group_embs = []
for emb in inputs:
g_dim = max(2, int(emb.shape[-1]) // g)
ghat = emb.shape[-1] // g_dim
group_embs.append(tf.reshape(emb, [-1, ghat, g_dim]))

squeezed = []
for emb in group_embs:
squeezed.append(tf.reduce_max(emb, axis=-1)) # [B, g]
squeezed.append(tf.reduce_mean(emb, axis=-1)) # [B, g]
z = tf.concat(squeezed, axis=1) # [bs, field_size * num_groups * 2]
z = tf.concat(squeezed, axis=1) # [bs, num_groups*field_size]

reduction_size = max(1, z.shape[-1] // r)

reduced = tf.layers.dense(
inputs=z,
Expand All @@ -64,6 +67,7 @@ def __call__(self, inputs):
inputs=reduced,
units=emb_size,
kernel_initializer='glorot_normal',
activation=self.excitation_acitvation,
name='%s/excite' % self._name)

# Re-weight
Expand Down
13 changes: 13 additions & 0 deletions easy_rec/python/model/dssm_senet.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,18 @@ def build_predict_graph(self):
num_fields=self.user_num_fields,
num_squeeze_group=self.user_tower.senet.num_squeeze_group,
reduction_ratio=self.user_tower.senet.reduction_ratio,
excitation_acitvation=self.user_tower.senet.excitation_acitvation,
l2_reg=self._l2_reg,
name='user_senet')
user_senet_output_list = user_senet(self.user_feature_list)
user_senet_output = tf.concat(user_senet_output_list, axis=-1)

user_senet_output = tf.layers.batch_normalization(
user_senet_output,
training=self._is_training,
trainable=True,
name='user_senet_bn')

num_user_dnn_layer = len(self.user_tower.dnn.hidden_units)
last_user_hidden = self.user_tower.dnn.hidden_units.pop()
user_dnn = dnn.DNN(self.user_tower.dnn, self._l2_reg, 'user_dnn',
Expand All @@ -76,11 +83,17 @@ def build_predict_graph(self):
num_fields=self.item_num_fields,
num_squeeze_group=self.item_tower.senet.num_squeeze_group,
reduction_ratio=self.item_tower.senet.reduction_ratio,
excitation_acitvation=self.item_tower.senet.excitation_acitvation,
l2_reg=self._l2_reg,
name='item_senet')

item_senet_output_list = item_senet(self.item_feature_list)
item_senet_output = tf.concat(item_senet_output_list, axis=-1)
item_senet_output = tf.layers.batch_normalization(
item_senet_output,
training=self._is_training,
trainable=True,
name='item_senet_bn')

num_item_dnn_layer = len(self.item_tower.dnn.hidden_units)
last_item_hidden = self.item_tower.dnn.hidden_units.pop()
Expand Down
11 changes: 9 additions & 2 deletions easy_rec/python/protos/dssm_senet.proto
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,18 @@ package protos;

import "easy_rec/python/protos/dnn.proto";
import "easy_rec/python/protos/simi.proto";
import "easy_rec/python/protos/layer.proto";


// Configuration of the SENet module embedded in each DSSM_SENet tower.
message SENet_Module {
// Dimension compression ratio applied to the squeezed vector; default 2.
required uint32 reduction_ratio = 1 [default = 2];
// Number of groups each feature embedding is split into; default 1.
optional uint32 num_squeeze_group = 2 [default = 1];
// Activation function of the excitation weights layer; default 'relu'.
// NOTE(review): the field name is spelled "acitvation" (sic); the configs and
// the Python callers reference this exact spelling, so renaming it would be
// a breaking change.
optional string excitation_acitvation = 3 [default = 'relu'];
}


message DSSM_SENet_Tower {
required string id = 1;
required SENet senet = 2;
required SENet_Module senet = 2;
required DNN dnn = 3;

};
Expand Down
10 changes: 6 additions & 4 deletions examples/configs/dssm_senet_on_taobao.config
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,9 @@ model_config:{
user_tower {
id: "user_id"
senet {
num_squeeze_group : 2
reduction_ratio: 4
num_squeeze_group : 1
reduction_ratio: 2
excitation_acitvation: 'relu'
}
dnn {
hidden_units: [256, 128, 64, 32]
Expand All @@ -270,8 +271,9 @@ model_config:{
item_tower {
id: "adgroup_id"
senet {
num_squeeze_group : 2
reduction_ratio: 4
num_squeeze_group : 1
reduction_ratio: 2
excitation_acitvation: 'relu'
}
dnn {
hidden_units: [256, 128, 64, 32]
Expand Down
2 changes: 1 addition & 1 deletion git-lfs/git_lfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def get_yes_no(msg):
'usage: python git_lfs.py [pull] [push] [add filename] [resolve_conflict]'
)
sys.exit(1)
home_directory = os.path.expanduser("~")
home_directory = os.path.expanduser('~')
with open('.git_oss_config_pub', 'r') as fin:
git_oss_data_dir = None
host = None
Expand Down