def _step(self, action):
    """Step the wrapped env and add a visit-count exploration bonus.

    The wrapped environment is stepped with ``action``. If its ``info``
    dict carries a ``'location'`` entry, the visit statistics held in
    ``self.counts`` are used to compute a log visit probability for that
    location, and the reward is shifted by
    ``self.explore_scale * (previous_logprob - current_logprob)``, i.e. the
    bonus is positive when the agent moves to a location with a lower
    visit probability than the one it was at on the previous step.

    Diagnostics written into ``info``:
        ``'log/visits'``         kernel-weighted visit count at the location
        ``'log/visit_logprob'``  log of that count divided by ``self.total``
        ``'log/explore_bonus'``  absolute value of the bonus added

    Args:
        action: Action forwarded unchanged to ``self.env.step``.

    Returns:
        The ``(obs, reward, done, info)`` tuple from the wrapped
        environment, with ``reward`` and ``info`` adjusted as described
        above. When ``info`` has no ``'location'`` the tuple is returned
        untouched and no statistics are updated.
    """
    obs, reward, done, info = self.env.step(action)

    location = info.get('location')
    if location is None:
        # Nothing to score against; pass the transition through as-is.
        return obs, reward, done, info

    # Shift into the coordinate frame of self.counts, which is offset by
    # `breadth` cells ("padding" in the original comment).
    # presumably location is an integer numpy array — it must support
    # `+` and `.tolist()`; verify against the env that produces it.
    index = tuple((location + self.breadth).tolist())

    # presumably extract_patch returns the window of self.counts around
    # `index` with the same shape as self.kernel — confirm in its definition.
    patch = extract_patch(self.counts, index, self.breadth)
    count = (self.kernel * patch).sum()
    info['log/visits'] = count

    # NOTE(review): if `count` is 0 this evaluates log(0) = -inf and the
    # bonus below becomes inf (or nan on the following step). Confirm that
    # counts/total are initialised so the weighted count is never zero.
    logprob = np.log(count / self.total)
    info['log/visit_logprob'] = logprob

    # Reward the drop in visit log-probability relative to the last step.
    bonus = self.explore_scale * (self.logprob - logprob)
    info['log/explore_bonus'] = np.abs(bonus)
    reward += bonus
    self.logprob = logprob

    if self.decay:
        # Decaying mode: scale down all existing counts; total is not
        # changed here.
        self.counts *= self.decay
    else:
        # Plain counting mode: total grows by one per recorded visit.
        self.total += 1
    # NOTE(review): the visit is assumed to be recorded in both modes —
    # confirm against the original layout.
    self.counts[index] += 1

    return obs, reward, done, info
评论列表
文章目录