ca_network.py source code


Project: a3c    Author: dm-mch
def _build_gradient(self, target):
        """
        Build the training op: compute gradients of the local loss and
        apply them to the remote (shared) variables held by `target`.
        """
        local_grad = tf.gradients(self.total_loss, self.get_trainable_weights())
        self.for_summary_scalar += [tf.global_norm(local_grad, name='grad_norm'),
                                    tf.global_norm(self.get_trainable_weights(), name='vars_norm')]
        # clip gradients by their global norm to stabilise training
        local_grad, _ = tf.clip_by_global_norm(local_grad, self.clip_grad_norm)

        # pair each local gradient with the corresponding remote variable
        remote_vars = target.get_trainable_weights()
        assert len(local_grad) == len(remote_vars)
        grads_and_vars = list(zip(local_grad, remote_vars))

        # each worker keeps its own set of Adam optimizer parameters
        optimizer = tf.train.AdamOptimizer(self.lr)

        # apply the gradients and advance the global step by the batch size
        apply_grad = optimizer.apply_gradients(grads_and_vars)
        inc_step = self.global_step.assign_add(tf.shape(self.x)[0])
        self.train_op = tf.group(apply_grad, inc_step)
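The `tf.clip_by_global_norm` call above rescales the entire gradient list jointly: it computes one norm over all tensors and, if that norm exceeds the threshold, scales every gradient by the same factor. As a rough illustration (a minimal NumPy sketch of the idea, not TensorFlow's actual implementation, which also handles `None` gradients and a precomputed `use_norm`):

```python
import numpy as np

def clip_by_global_norm(grads, clip_norm):
    """Scale all gradients jointly so their global norm is at most clip_norm."""
    global_norm = np.sqrt(sum(np.sum(np.square(g)) for g in grads))
    if global_norm > clip_norm:
        scale = clip_norm / global_norm
        grads = [g * scale for g in grads]
    return grads, global_norm

grads = [np.array([3.0, 4.0])]                     # global norm = 5.0
clipped, norm = clip_by_global_norm(grads, 2.5)    # scaled by 2.5 / 5.0 = 0.5
# clipped[0] -> [1.5, 2.0], and its norm is now exactly 2.5
```

Because all tensors share one scale factor, the relative magnitudes between layers' gradients are preserved; only the overall step size is capped.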