es.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:rl_algorithms 作者: DanielTakeshi 项目源码 文件源码
def test(self, just_one=True):
        """Evaluate saved weight snapshots at test time. No training is done.

        Wraps ``self.env`` in ``wrappers.Monitor`` so episodes are recorded to
        ``<self.args.directory>/videos``. Then it iterates over every file in
        ``<self.args.directory>/snapshots``. The files are visited in sorted
        (lexicographic) order, so numeric names without zero-padding will not
        be visited in numeric order.

        For each snapshot, it:

        1. Unpickles the weights.
        2. Loads them via ``self.set_params_op``, which is fed through
           ``self.new_weights_v``.
        3. Runs ``self._compute_return(test=True)`` ``num_rollouts`` times.
        4. Prints the mean/std/max/min and the raw list of returns.

        Args:
            just_one: If True, run a single rollout per snapshot instead of
                10. OpenAI's monitor only records a subset of episodes, and
                records less frequently as more run, so one rollout per
                snapshot gets things recorded right away. This does NOT limit
                evaluation to a single snapshot. To do that, add a filter in
                the snapshot loop (see the criteria comment there).
        """
        video_dir = os.path.join(self.args.directory, 'videos')
        # exist_ok: re-running evaluation in the same directory must not crash
        # here; the Monitor below is already told to overwrite (force=True).
        os.makedirs(video_dir, exist_ok=True)
        self.env = wrappers.Monitor(self.env, video_dir, force=True)

        headdir = os.path.join(self.args.directory, 'snapshots')
        snapshots = sorted(os.listdir(headdir))
        num_rollouts = 1 if just_one else 10

        for sn in snapshots:
            print("\n***** Currently on snapshot {} *****".format(sn))

            ### Add your own criteria here, e.g. to evaluate only some snapshots.
            # if "800" not in sn:
            #     continue
            ###

            # NOTE(review): pickle.load can execute arbitrary code; only load
            # snapshot files produced by this project, never untrusted ones.
            with open(os.path.join(headdir, sn), 'rb') as f:
                weights = pickle.load(f)
            self.sess.run(self.set_params_op,
                          feed_dict={self.new_weights_v: weights})

            returns = [self._compute_return(test=True)
                       for _ in range(num_rollouts)]
            print("mean: \t{}".format(np.mean(returns)))
            print("std: \t{}".format(np.std(returns)))
            print("max: \t{}".format(np.max(returns)))
            print("min: \t{}".format(np.min(returns)))
            print("returns:\n{}".format(returns))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号