exploration.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:human-rl 作者: gsastry 项目源码 文件源码
def _step(self, action):
        """Step the wrapped env and add a visit-count exploration bonus.

        The action is forwarded to ``self.env.step``. If the returned
        ``info`` dict carries a ``'location'`` entry, a kernel-weighted
        visit count around that location is read from ``self.counts``,
        turned into a log-probability, and the change in log-probability
        since the previous located step (scaled by ``self.explore_scale``)
        is added to the reward. Diagnostic values are written into
        ``info`` under the ``log/...`` keys.

        Returns the usual ``(obs, reward, done, info)`` tuple; when no
        location is reported the env's values are passed through untouched.
        """
        obs, reward, done, info = self.env.step(action)

        raw_location = info.get('location')
        if raw_location is None:
                # Nothing to count against: pass the transition through as-is.
                return obs, reward, done, info

        # Offset by the padding width so the patch around the location stays
        # in bounds (presumably self.counts is padded by `breadth` per side
        # -- TODO confirm against the constructor).
        cell = tuple((raw_location + self.breadth).tolist())

        # Kernel-weighted visit mass in the neighbourhood of the current cell.
        neighbourhood = extract_patch(self.counts, cell, self.breadth)
        visits = (self.kernel * neighbourhood).sum()
        info['log/visits'] = visits

        visit_logprob = np.log(visits / self.total)
        info['log/visit_logprob'] = visit_logprob

        # Positive when the new location is less visited than the previous one.
        bonus = self.explore_scale * (self.logprob - visit_logprob)
        info['log/explore_bonus'] = np.abs(bonus)
        reward += bonus

        # Remember this step's log-probability for the next comparison.
        self.logprob = visit_logprob

        if not self.decay:
                # Without decay the normaliser grows with every recorded visit.
                self.total += 1
        else:
                # With decay, older visits are scaled down instead.
                self.counts *= self.decay

        # Record the visit only after the bonus was computed from the old counts.
        self.counts[cell] += 1

        return obs, reward, done, info
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号