model.py source code

python

Project: Saliency_Detection_Convolutional_Autoencoder  Author: arthurmeyer
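import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.compat.v1 under TF 2.x)


def train(self, loss, global_step):
    """
    Return the training op for the TensorFlow graph.

    Running the op applies one Adam update to every trainable variable,
    increments global_step, and then updates the exponential moving
    averages (shadow copies) of the trainable variables.

    Args:
      loss                   : scalar loss tensor to minimize
      global_step            : step counter variable; incremented by apply_gradients
    """

    # Compute and apply Adam gradients; apply_gradients also increments global_step.
    opt = tf.train.AdamOptimizer(self.learning_rate)
    grads = opt.compute_gradients(loss)
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    # Track moving averages of the trainable variables. Passing global_step as
    # num_updates makes the effective decay smaller early in training.
    variable_averages = tf.train.ExponentialMovingAverage(self.moving_avg_decay, global_step)

    # Update the averages only after the gradient step, so they see the new values.
    with tf.control_dependencies([apply_gradient_op]):
        variables_averages_op = variable_averages.apply(tf.trainable_variables())

    # Group both updates into a single op that is run once per training step.
    with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
        train_op = tf.no_op(name='train')

    return train_op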
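For readers without a TensorFlow 1.x setup, here is a framework-free sketch of roughly what one run of `train_op` computes, using scalar parameters. It is an illustration, not the project's code, and it rests on a few assumptions. `TrainState`, `adam_step`, `ema_update` and `train_step` are hypothetical names. Adam follows the update rule and default `beta1`/`beta2`/`epsilon` documented for `tf.train.AdamOptimizer`. The moving average uses the `num_updates` cap documented for `tf.train.ExponentialMovingAverage`, `min(decay, (1 + step) / (10 + step))`. Finally, the averages are assumed to be updated after the gradient step and to read the already-incremented step.

```python
import math


class TrainState:
    """Hypothetical stand-in for the variables the TF graph keeps."""

    def __init__(self, params):
        self.params = dict(params)               # trainable variables (name -> float)
        self.shadow = dict(params)               # EMA shadows start at the initial values
        self.m = {name: 0.0 for name in params}  # Adam first-moment accumulators
        self.v = {name: 0.0 for name in params}  # Adam second-moment accumulators
        self.global_step = 0


def adam_step(state, grads, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update in the form documented for tf.train.AdamOptimizer."""
    t = state.global_step + 1
    lr_t = lr * math.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)  # bias-corrected step size
    for name, g in grads.items():
        state.m[name] = beta1 * state.m[name] + (1 - beta1) * g
        state.v[name] = beta2 * state.v[name] + (1 - beta2) * g * g
        state.params[name] -= lr_t * state.m[name] / (math.sqrt(state.v[name]) + eps)
    state.global_step = t  # mirrors apply_gradients(..., global_step=global_step)


def ema_update(state, decay):
    """Move each shadow toward its variable; decay is capped by the step count."""
    n = state.global_step
    d = min(decay, (1 + n) / (10 + n))  # num_updates cap (assumed to see the new step)
    for name, value in state.params.items():
        state.shadow[name] = d * state.shadow[name] + (1 - d) * value


def train_step(state, grad_fn, learning_rate, moving_avg_decay):
    """Roughly what one run of train_op does, in the assumed order."""
    grads = grad_fn(state.params)              # opt.compute_gradients(loss)
    adam_step(state, grads, lr=learning_rate)  # opt.apply_gradients(...)
    ema_update(state, moving_avg_decay)        # variable_averages.apply(...)
```

For example, take `grad_fn = lambda p: {"w": 2 * (p["w"] - 3)}`, the gradient of (w - 3)², with `w` starting at 0.0. Repeated `train_step` calls then move `w` toward 3, while `state.shadow["w"]` follows as a smoothed copy. Early in training the cap keeps the effective decay well below a fixed value such as 0.9999. Without it, the shadow would stay close to its initial value for thousands of steps.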