def train_simple(total_loss, global_step):
    with tf.variable_scope('train_op'):
        # Variables that affect learning rate.
        num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size
        decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)

        # Decay the learning rate exponentially based on the number of steps.
        lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
                                        global_step,
                                        decay_steps,
                                        LEARNING_RATE_DECAY_FACTOR,
                                        staircase=True)
        tf.summary.scalar('learning_rate', lr)

        # Alternative (commented out): momentum SGD with an explicit dependency
        # on UPDATE_OPS, which is needed when the model uses
        # tf.layers.batch_normalization so the moving mean/variance get updated.
        # update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        # with tf.control_dependencies(update_ops):
        #     opt = tf.train.MomentumOptimizer(lr, 0.9).minimize(total_loss, global_step=global_step)
        opt = tf.train.AdamOptimizer(lr).minimize(total_loss, global_step=global_step)
        tf.summary.scalar(total_loss.op.name + ' (raw)', total_loss)
    return opt, lr
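

# ---------------------------------------------------------------------------
# A minimal usage sketch (an assumption, not part of the original code):
# wiring train_simple() into a plain TF1 training loop. build_loss() and
# MAX_STEPS are hypothetical placeholders; the constants referenced above
# (NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN, INITIAL_LEARNING_RATE, etc.) are assumed
# to be defined at module level, as in the snippet.
# ---------------------------------------------------------------------------
def main(_):
    global_step = tf.Variable(0, trainable=False, name='global_step')
    total_loss = build_loss()  # hypothetical helper that builds the model and its loss
    train_op, lr = train_simple(total_loss, global_step)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for step in range(MAX_STEPS):
            _, loss_val, lr_val = sess.run([train_op, total_loss, lr])
            if step % 100 == 0:
                print('step %d  loss %.4f  lr %.6f' % (step, loss_val, lr_val))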