kfac.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:baselines 作者: openai 项目源码 文件源码
def apply_gradients(self, grads):
        """Build the update op, using momentum SGD until the cold-start phase ends.

        Two candidate update ops are constructed and one is selected at run
        time with ``tf.cond``: while ``self.sgd_step`` is not greater than
        ``self._cold_iter`` the momentum-SGD ("cold") op is used, afterwards
        the op returned by ``self.apply_gradients_kfac``.

        Args:
            grads: iterable of ``(gradient, variable)`` pairs.

        Returns:
            A ``(update_op, qr)`` tuple, where ``qr`` is the second value
            returned by ``self.apply_gradients_kfac(grads)`` (presumably a
            queue runner -- confirm against that method).
        """
        coldOptim = tf.train.MomentumOptimizer(
            self._cold_lr, self._momentum)

        def coldSGDstart():
            # Split the (gradient, variable) pairs so the gradients can be
            # clipped together by their global norm.
            sgd_grads, sgd_var = zip(*grads)

            if self.max_grad_norm is not None:
                # The global norm returned alongside the clipped gradients
                # is not needed here.
                sgd_grads, _ = tf.clip_by_global_norm(
                    sgd_grads, self.max_grad_norm)

            sgd_grads = list(zip(sgd_grads, sgd_var))

            # Count this cold step; the counter drives the tf.cond below.
            sgd_step_op = tf.assign_add(self.sgd_step, 1)
            coldOptim_op = coldOptim.apply_gradients(sgd_grads)
            if KFAC_DEBUG:
                # Print only after both the counter increment and the SGD
                # update have been scheduled.
                with tf.control_dependencies([sgd_step_op, coldOptim_op]):
                    sgd_step_op = tf.Print(
                        sgd_step_op, [self.sgd_step, tf.convert_to_tensor('doing cold sgd step')])
            return tf.group(*[sgd_step_op, coldOptim_op])

        # NOTE(review): kfacOptim_op is created outside the tf.cond branch
        # functions. Per the tf.cond documentation, ops created outside the
        # branch callables run regardless of which branch is selected --
        # confirm this is the intended behaviour during the cold phase.
        kfacOptim_op, qr = self.apply_gradients_kfac(grads)

        def warmKFACstart():
            return kfacOptim_op

        return tf.cond(tf.greater(self.sgd_step, self._cold_iter), warmKFACstart, coldSGDstart), qr
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号