kfac.py source code

Project: acktr  Author: emansim
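import tensorflow as tf  # TensorFlow 1.x graph-mode API

# Excerpted method of the optimizer class; KFAC_DEBUG is assumed to be a
# module-level boolean defined elsewhere in kfac.py.
def apply_gradients(self, grads):
    # grads: list of (gradient, variable) pairs.
    # Cold start: momentum SGD. The lr is scaled by (1 - momentum) so that, for a
    # constant gradient, the steady-state step equals lr * grad, as in plain SGD.
    coldOptim = tf.train.MomentumOptimizer(
        self._cold_lr * (1. - self._momentum), self._momentum)

    def coldSGDstart():
        # Count this cold step and apply the momentum update.
        sgd_step_op = tf.assign_add(self.sgd_step, 1)
        coldOptim_op = coldOptim.apply_gradients(grads)
        if KFAC_DEBUG:
            # Debug only: log the step counter after both updates.
            with tf.control_dependencies([sgd_step_op, coldOptim_op]):
                sgd_step_op = tf.Print(
                    sgd_step_op, [self.sgd_step, tf.convert_to_tensor('doing cold sgd step')])
        return tf.group(sgd_step_op, coldOptim_op)

    # Warm start: the K-FAC update op; qr is passed through to the caller.
    # Note: this op is created outside both branch functions, and in TF 1.x
    # tf.cond only ops created inside true_fn/false_fn execute conditionally.
    kfacOptim_op, qr = self.apply_gradients_kfac(grads)

    def warmKFACstart():
        return kfacOptim_op

    # Use K-FAC once the cold-step counter exceeds _cold_iter, otherwise cold SGD.
    return tf.cond(tf.greater(self.sgd_step, self._cold_iter),
                   warmKFACstart, coldSGDstart), qr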
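The `(1. - self._momentum)` factor on the cold-start learning rate can be checked numerically. The sketch below is plain Python rather than TensorFlow and assumes the non-Nesterov `tf.train.MomentumOptimizer` update rule `accum = momentum * accum + grad; var -= lr * accum`; `momentum_steps` is a hypothetical helper, not part of kfac.py.

```python
# Plain-Python sketch of TF-style momentum with the cold-start lr scaling.
# Assumed update rule (non-Nesterov):
#     accum = momentum * accum + grad
#     var  -= lr * accum
# A constant gradient g drives accum toward g / (1 - momentum), so scaling lr
# by (1 - momentum) gives a steady-state step of lr * g, like plain SGD.


def momentum_steps(lr, momentum, grad, n_steps):
    """Return the per-step parameter change for a constant gradient."""
    scaled_lr = lr * (1.0 - momentum)  # same scaling as coldOptim
    accum = 0.0
    steps = []
    for _ in range(n_steps):
        accum = momentum * accum + grad
        steps.append(scaled_lr * accum)
    return steps


if __name__ == "__main__":
    steps = momentum_steps(lr=0.1, momentum=0.9, grad=1.0, n_steps=200)
    print(steps[0])   # first step is only lr * (1 - momentum) * g
    print(steps[-1])  # later steps approach lr * g
```

For a constant gradient the first step is only `lr * (1 - momentum) * g`, and the step size grows toward `lr * g` as momentum builds up, so the cold phase ends up behaving like plain SGD at rate `self._cold_lr`.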
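Which branch runs depends only on the step counter, and only `coldSGDstart` increments it. The following plain-Python sketch replays that schedule; `run_schedule` is hypothetical, and it assumes `self.sgd_step` starts at 0 and that each run of the returned op performs exactly one update.

```python
# Plain-Python sketch of the cold-start -> K-FAC switch in apply_gradients.
# The counter advances only on cold SGD steps; once it exceeds cold_iter,
# every later update takes the K-FAC branch.


def run_schedule(cold_iter, n_calls):
    """Return the branch ('sgd' or 'kfac') chosen on each of n_calls updates."""
    sgd_step = 0  # mirrors self.sgd_step, assumed to start at 0
    chosen = []
    for _ in range(n_calls):
        if sgd_step > cold_iter:   # tf.greater(self.sgd_step, self._cold_iter)
            chosen.append("kfac")  # warmKFACstart
        else:
            sgd_step += 1          # tf.assign_add(self.sgd_step, 1)
            chosen.append("sgd")   # coldSGDstart
    return chosen


if __name__ == "__main__":
    print(run_schedule(cold_iter=3, n_calls=6))
    # ['sgd', 'sgd', 'sgd', 'sgd', 'kfac', 'kfac']
```

With `cold_iter=3`, four SGD steps run before K-FAC takes over, because the comparison is a strict `>` on a counter that starts at 0.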