def train(self, loss, global_step):
    """Build the training op for the TensorFlow graph.

    Args:
        loss: scalar loss tensor to minimize with Adam.
        global_step: variable tracking the current training step.
    """
    # Compute and apply gradients with the Adam optimizer.
    opt = tf.train.AdamOptimizer(self.learning_rate)
    grads = opt.compute_gradients(loss)
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    # Maintain an exponential moving average of all trainable variables.
    variable_averages = tf.train.ExponentialMovingAverage(
        self.moving_avg_decay, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())

    # Group both updates behind a single no-op so that running `train_op`
    # performs the gradient step and the EMA update together.
    with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
        train_op = tf.no_op(name='train')
    return train_op
model.py source code
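Below is a minimal, self-contained sketch (TensorFlow 1.x) of how a training op built this way might be driven. The toy linear model, placeholder data, and hyperparameter values are assumptions for illustration only and are not part of the original model.py; the sketch simply mirrors the Adam-plus-EMA pattern of `train()` outside the class so it can run on its own.

import numpy as np
import tensorflow as tf

# Hypothetical stand-in model: a single linear layer with a squared-error loss.
x = tf.placeholder(tf.float32, [None, 1])
y = tf.placeholder(tf.float32, [None, 1])
w = tf.get_variable('w', [1, 1])
b = tf.get_variable('b', [1], initializer=tf.zeros_initializer())
loss = tf.reduce_mean(tf.square(tf.matmul(x, w) + b - y))

global_step = tf.train.get_or_create_global_step()

# Same pattern as train() above: Adam gradient step plus an EMA of the
# trainable variables, grouped behind one no-op.
opt = tf.train.AdamOptimizer(1e-2)  # assumed learning rate
apply_gradient_op = opt.apply_gradients(opt.compute_gradients(loss),
                                        global_step=global_step)
ema = tf.train.ExponentialMovingAverage(0.999, global_step)  # assumed decay
ema_op = ema.apply(tf.trainable_variables())
with tf.control_dependencies([apply_gradient_op, ema_op]):
    train_op = tf.no_op(name='train')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    data_x = np.random.randn(64, 1).astype(np.float32)
    data_y = 3.0 * data_x + 1.0
    for _ in range(200):
        # One run of train_op advances the weights, the EMA, and global_step.
        _, step, loss_val = sess.run([train_op, global_step, loss],
                                     feed_dict={x: data_x, y: data_y})

Grouping the two ops behind `tf.no_op` means callers only ever fetch `train_op`; they cannot accidentally run the gradient update without also refreshing the moving averages.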