
Commit 70cae32

fix optimizer choice, use same default params for train_gpu

1 parent 8f17cfd

File tree

2 files changed (+22, -20 lines)

model_utils.py (+6, -4)
@@ -122,19 +122,21 @@ def get_train_op(FLAGS, total_loss, grads_and_vars=None):
   learning_rate = tf.where(global_step < FLAGS.warmup_steps,
                            warmup_lr, decay_lr)
 
+  if (FLAGS.weight_decay > 0 and not FLAGS.use_tpu and
+      FLAGS.num_core_per_host > 1):
+    raise ValueError("Do not support `weight_decay > 0` with multi-gpu "
+                     "training so far.")
+
   if FLAGS.weight_decay == 0:
     optimizer = tf.train.AdamOptimizer(
         learning_rate=learning_rate,
         epsilon=FLAGS.adam_epsilon)
-  elif FLAGS.weight_decay > 0 and FLAGS.num_core_per_host == 1:
+  else:
     optimizer = AdamWeightDecayOptimizer(
         learning_rate=learning_rate,
         epsilon=FLAGS.adam_epsilon,
         exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
         weight_decay_rate=FLAGS.weight_decay)
-  else:
-    raise ValueError("Do not support `weight_decay > 0` with multi-gpu "
-                     "training so far.")
 
   if FLAGS.use_tpu:
     optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
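
After this commit, the unsupported combination (weight decay with multi-GPU, non-TPU training) is rejected up front, and the optimizer is then chosen purely on whether weight decay is requested. The sketch below restates that selection as a standalone function, assuming the TF 1.x API; the name select_optimizer and the explicit parameters are illustrative (in the repo the logic sits inside get_train_op and reads FLAGS), and AdamWeightDecayOptimizer is the repo-local class referenced in the diff above.

import tensorflow as tf  # TF 1.x: tf.train.AdamOptimizer, tf.contrib.tpu

def select_optimizer(learning_rate, weight_decay, adam_epsilon,
                     use_tpu=False, num_core_per_host=1):
  # Fail fast: decoupled weight decay is not supported for multi-GPU training.
  if weight_decay > 0 and not use_tpu and num_core_per_host > 1:
    raise ValueError("Do not support `weight_decay > 0` with multi-gpu "
                     "training so far.")

  if weight_decay == 0:
    # Plain Adam when no weight decay is requested.
    optimizer = tf.train.AdamOptimizer(
        learning_rate=learning_rate, epsilon=adam_epsilon)
  else:
    # AdamWeightDecayOptimizer (defined alongside this code in model_utils.py)
    # applies decoupled weight decay, skipping LayerNorm and bias parameters.
    optimizer = AdamWeightDecayOptimizer(
        learning_rate=learning_rate,
        epsilon=adam_epsilon,
        exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
        weight_decay_rate=weight_decay)

  if use_tpu:
    # On TPU, gradients are aggregated across replicas before being applied.
    optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
  return optimizer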

train_gpu.py (+16, -16)
@@ -39,12 +39,12 @@
       help="checkpoint path for initializing the model.")
 
 # Optimization config
-flags.DEFINE_float("learning_rate", default=2.5e-4,
+flags.DEFINE_float("learning_rate", default=1e-4,
       help="Maximum learning rate.")
-flags.DEFINE_float("clip", default=0.25,
+flags.DEFINE_float("clip", default=1.0,
       help="Gradient clipping value.")
 # for cosine decay
-flags.DEFINE_float("min_lr_ratio", default=0.004,
+flags.DEFINE_float("min_lr_ratio", default=0.001,
       help="Minimum ratio learning rate.")
 flags.DEFINE_integer("warmup_steps", default=0,
       help="Number of steps for linear lr warmup.")
@@ -56,13 +56,13 @@
       help="weight decay")
 
 # Training config
-flags.DEFINE_integer("train_batch_size", default=60,
+flags.DEFINE_integer("train_batch_size", default=16,
       help="Size of train batch.")
 flags.DEFINE_integer("train_steps", default=100000,
       help="Total number of training steps.")
-flags.DEFINE_integer("iterations", default=500,
+flags.DEFINE_integer("iterations", default=1000,
       help="Number of iterations per repeat loop.")
-flags.DEFINE_integer("save_steps", default=10000,
+flags.DEFINE_integer("save_steps", default=None,
       help="number of steps for model checkpointing.")
 
 # Data config
@@ -73,7 +73,7 @@
       "Could be half of seq_len")
 flags.DEFINE_bool("bi_data", default=True,
       help="Use bidirectional data streams, i.e., forward & backward.")
-flags.DEFINE_integer("mask_alpha", default=2,
+flags.DEFINE_integer("mask_alpha", default=6,
       help="How many tokens to form a group.")
 flags.DEFINE_integer("mask_beta", default=1,
       help="How many tokens to mask within each group.")
@@ -86,7 +86,7 @@
 flags.DEFINE_integer("n_token", 32000, help="Vocab size")
 
 # Model config
-flags.DEFINE_integer("mem_len", default=70,
+flags.DEFINE_integer("mem_len", default=0,
       help="Number of steps to cache")
 flags.DEFINE_bool("same_length", default=False,
       help="Same length attention")
@@ -95,23 +95,23 @@
 
 flags.DEFINE_integer("n_layer", default=6,
       help="Number of layers.")
-flags.DEFINE_integer("d_model", default=500,
+flags.DEFINE_integer("d_model", default=32,
       help="Dimension of the model.")
-flags.DEFINE_integer("d_embed", default=500,
+flags.DEFINE_integer("d_embed", default=32,
       help="Dimension of the embeddings.")
-flags.DEFINE_integer("n_head", default=10,
+flags.DEFINE_integer("n_head", default=4,
       help="Number of attention heads.")
-flags.DEFINE_integer("d_head", default=50,
+flags.DEFINE_integer("d_head", default=8,
       help="Dimension of each attention head.")
-flags.DEFINE_integer("d_inner", default=1000,
+flags.DEFINE_integer("d_inner", default=32,
       help="Dimension of inner hidden size in positionwise feed-forward.")
-flags.DEFINE_float("dropout", default=0.1,
+flags.DEFINE_float("dropout", default=0.0,
       help="Dropout rate.")
-flags.DEFINE_float("dropatt", default=0.1,
+flags.DEFINE_float("dropatt", default=0.0,
       help="Attention dropout rate.")
 flags.DEFINE_bool("untie_r", default=False,
       help="Untie r_w_bias and r_r_bias")
-flags.DEFINE_string("summary_type", default="attn",
+flags.DEFINE_string("summary_type", default="last",
       help="Method used to summarize a sequence into a compact vector.")
 flags.DEFINE_string("ff_activation", default="relu",
       help="Activation type used in position-wise feed-forward.")
