HiPMDP.py source code

python

Project: hip-mdp-public  Author: dtak
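# Method excerpt from a class in HiPMDP.py; assumes `import numpy as np` at module level.
def __compute_bnn_training_error(self):
    """Compute BNN training error on most recent episode."""
    # One row per stored transition; columns 0, 1 and 3 are read below.
    exp = np.reshape(self.episode_buffer_bnn, (len(self.episode_buffer_bnn), -1))
    # Inputs: columns 0 and 1 of each transition, concatenated into one row.
    # `range` replaces the Python 2-only `xrange` and works on both versions.
    episode_X = np.array([np.hstack([exp[tt, 0], exp[tt, 1]]) for tt in range(exp.shape[0])])
    # Targets: column 3 of each transition.
    episode_Y = np.array([exp[tt, 3] for tt in range(exp.shape[0])])
    if self.state_diffs:
        # subtract previous state
        episode_Y -= episode_X[:, :self.num_dims]
    # Append the current weight set to every input row before querying the network.
    bnn_input = np.hstack([episode_X, np.tile(self.weight_set, (episode_X.shape[0], 1))])
    l2_errors = self.network.get_td_error(bnn_input, episode_Y, 0.0, 1.0)
    # Record the mean and standard deviation of the errors for this instance and episode.
    self.mean_episode_errors[self.instance_iter, self.episode_iter] = np.mean(l2_errors)
    self.std_episode_errors[self.instance_iter, self.episode_iter] = np.std(l2_errors)
    if self.print_output:
        print('BNN Error: {}'.format(self.mean_episode_errors[self.instance_iter, self.episode_iter]))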
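
The array handling in the method above can be tried in isolation. The following is a minimal sketch, not code from the repository: it assumes each buffered transition is a `(state, action, reward, next_state)` tuple (the method only shows that columns 0, 1 and 3 are used), and `build_bnn_targets`, the sample transitions and the `weights` vector are hypothetical names and values chosen for illustration. It reproduces only the input/target construction and the weight tiling, not the call to the network.

```python
import numpy as np


def build_bnn_targets(transitions, num_dims, state_diffs=True):
    """Build input rows and targets from a list of transitions.

    Assumed layout (hypothetical): (state, action, reward, next_state).
    """
    # Inputs: state and action concatenated into one row per transition.
    X = np.array([np.hstack([s, a]) for s, a, _, _ in transitions], dtype=float)
    # Targets: the next state of each transition.
    Y = np.array([s_next for _, _, _, s_next in transitions], dtype=float)
    if state_diffs:
        # Predict the change in state rather than the next state itself.
        Y = Y - X[:, :num_dims]
    return X, Y


# Two made-up transitions with a 2-dimensional state and a scalar action.
transitions = [
    (np.array([0.0, 1.0]), 1, 0.5, np.array([0.5, 1.5])),
    (np.array([0.5, 1.5]), 0, -0.2, np.array([0.25, 1.0])),
]
X, Y = build_bnn_targets(transitions, num_dims=2)

# Hypothetical weight vector, repeated once per row and appended to the
# inputs, mirroring the np.tile / np.hstack step in the method.
weights = np.array([0.1, -0.3])
bnn_input = np.hstack([X, np.tile(weights, (X.shape[0], 1))])

print(X.shape, Y.shape, bnn_input.shape)
```

With `state_diffs=True` the targets are per-dimension state changes, so the slice `X[:, :num_dims]` must cover exactly the state part of each input row.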