import tensorflow as tf


class adict(dict):
    ''' Attribute dictionary: entries readable as attributes.
        (Minimal stand-in for the helper defined elsewhere in the original source.) '''
    def __init__(self, *av, **kav):
        dict.__init__(self, *av, **kav)
        self.__dict__ = self


def training_graph(loss, learning_rate=1.0, max_grad_norm=5.0):
    ''' Builds the SGD training graph with global-norm gradient clipping. '''
    global_step = tf.Variable(0, name='global_step', trainable=False)

    with tf.variable_scope('SGD_Training'):
        # SGD learning rate, kept as a non-trainable variable so it can be
        # reassigned later (e.g. for learning-rate decay)
        learning_rate = tf.Variable(learning_rate, trainable=False, name='learning_rate')

        # collect all trainable variables and clip their gradients by global norm
        tvars = tf.trainable_variables()
        grads, global_norm = tf.clip_by_global_norm(tf.gradients(loss, tvars), max_grad_norm)

        # apply the clipped gradients; global_step increments on every update
        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
        train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=global_step)

    return adict(
        learning_rate=learning_rate,
        global_step=global_step,
        global_norm=global_norm,
        train_op=train_op)
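

# Usage sketch, not from the original source: assumes TF 1.x graph mode with a
# scalar `loss` tensor already built; `loss` and `num_steps` are placeholders here.
# train = training_graph(loss, learning_rate=1.0, max_grad_norm=5.0)
# with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     for _ in range(num_steps):
#         _, step, norm = sess.run([train.train_op, train.global_step, train.global_norm])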