kfac.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:baselines 作者: openai 项目源码 文件源码
def apply_stats(self, statsUpdates):
        """ compute stats and update/apply the new stats to the running average
        """

        def updateAccumStats():
            if self._full_stats_init:
                return tf.cond(tf.greater(self.sgd_step, self._cold_iter), lambda: tf.group(*self._apply_stats(statsUpdates, accumulate=True, accumulateCoeff=1. / self._stats_accum_iter)), tf.no_op)
            else:
                return tf.group(*self._apply_stats(statsUpdates, accumulate=True, accumulateCoeff=1. / self._stats_accum_iter))

        def updateRunningAvgStats(statsUpdates, fac_iter=1):
            # return tf.cond(tf.greater_equal(self.factor_step,
            # tf.convert_to_tensor(fac_iter)), lambda:
            # tf.group(*self._apply_stats(stats_list, varlist)), tf.no_op)
            return tf.group(*self._apply_stats(statsUpdates))

        if self._async_stats:
            # asynchronous stats update
            update_stats = self._apply_stats(statsUpdates)

            queue = tf.FIFOQueue(1, [item.dtype for item in update_stats], shapes=[
                                 item.get_shape() for item in update_stats])
            enqueue_op = queue.enqueue(update_stats)

            def dequeue_stats_op():
                return queue.dequeue()
            self.qr_stats = tf.train.QueueRunner(queue, [enqueue_op])
            update_stats_op = tf.cond(tf.equal(queue.size(), tf.convert_to_tensor(
                0)), tf.no_op, lambda: tf.group(*[dequeue_stats_op(), ]))
        else:
            # synchronous stats update
            update_stats_op = tf.cond(tf.greater_equal(
                self.stats_step, self._stats_accum_iter), lambda: updateRunningAvgStats(statsUpdates), updateAccumStats)
        self._update_stats_op = update_stats_op
        return update_stats_op
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号