def sample(self, n):
    """
    Draw n distinct stored entries uniformly at random from the memory.

    A single set of indices is drawn without replacement from
    ``range(self.cur_size)`` and applied to every storage array, so the
    i-th element of each returned array comes from the same stored entry.

    Args:
        n: Number of entries to draw; it cannot exceed ``self.cur_size``
            because sampling is done without replacement.

    Returns:
        Tuple ``(s1, a, r, s2, t)`` taken from ``self.S1``, ``self.A``,
        ``self.R``, ``self.S2`` and ``self.T`` respectively.
        NOTE(review): presumably states, actions, rewards, next states and
        terminal flags -- confirm against the code that fills these arrays.

    Raises:
        ValueError: If n entries cannot be drawn without replacement from
            the ``self.cur_size`` entries currently stored.
    """
    try:
        indices = np.random.choice(self.cur_size, n, replace=False)
    except ValueError as err:
        # NumPy's own message mentions neither the requested count nor the
        # current fill level; re-raise with both so the failure is
        # diagnosable. The exception type callers see is unchanged.
        raise ValueError(
            "cannot sample {} elements without replacement from a memory "
            "holding {} elements".format(n, self.cur_size)
        ) from err

    # S1/S2 are indexed along axis 0 so each picked entry keeps its trailing
    # dimensions; A, R and T use np.take without an axis, which indexes into
    # the flattened array.
    s1 = np.take(self.S1, indices, axis=0)
    a = np.take(self.A, indices)
    r = np.take(self.R, indices)
    s2 = np.take(self.S2, indices, axis=0)
    t = np.take(self.T, indices)
    return s1, a, r, s2, t
评论列表
文章目录