def sample(self, amount, temperature=1):
    """Draw a prioritized batch of ``amount`` entries from the buffer.

    Positions are drawn independently (with replacement, as
    ``tf.multinomial`` does) from the first ``self.size()`` slots. Each slot
    is picked with probability proportional to
    ``priority ** (1 / temperature)``. A ``temperature`` above 1 flattens
    the distribution towards uniform; one below 1 sharpens it towards the
    highest priorities.

    Args:
        amount: Number of positions to draw.
        temperature: Strictly positive scale applied to the log-priorities.

    Returns:
        A list with one tensor per entry of ``self.buffer.buffers`` other
        than ``'priority'``, in that mapping's iteration order. Each tensor
        is ``tf.gather`` of that buffer at the sampled positions.

    Raises:
        ValueError: If ``temperature`` is a plain number that is not
            strictly positive.
    """
    import numbers

    # A zero temperature would divide by zero: NaN/inf logits make
    # tf.multinomial return undefined indices, which then reach tf.gather.
    # A negative one would silently invert the priority ordering.
    # Only plain numbers can be checked eagerly here; tensor temperatures
    # are passed through unchanged.
    if isinstance(temperature, numbers.Real) and not temperature > 0:
        raise ValueError(
            "temperature must be > 0, got {!r}".format(temperature))

    # NOTE(review): assumes self.size() is the number of filled slots, so
    # that only valid priorities take part in sampling -- confirm.
    priorities = self.buffer.buffers['priority'].value()[:self.size()]
    # tf.multinomial expects *unnormalized* log-probabilities, so dividing
    # by the sum first is unnecessary. Skipping that division also avoids
    # underflowing tiny priorities to zero when the total is large.
    logits = tf.log(priorities) / temperature
    # tf.multinomial wants a 2-D [batch, classes] input: add a batch
    # dimension of 1, then take that single row of sampled indices.
    positions = tf.multinomial(logits[None, ...], amount)[0]
    return [tf.gather(values, positions)
            for key, values in self.buffer.buffers.items()
            if key != 'priority']
评论列表
文章目录