def query(self, elements):
    """Return a batch drawn from a history pool of previously seen elements.

    Each element of ``elements`` (iterated along its first dimension) is
    handled independently:

    * While the pool holds fewer than ``self.capacity`` entries, the element
      is stored in the pool and also returned.
    * Once the pool is full, with probability 0.5 a randomly chosen stored
      entry is returned and replaced in the pool by the new element;
      otherwise the new element itself is returned.

    Args:
        elements: A tensor whose first dimension indexes the individual
            elements of the batch.

    Returns:
        ``elements`` itself when ``self.capacity`` is 0. Otherwise a new
        tensor with the same size along dimension 0, assembled from
        ``elements.data`` and previously stored entries (so it carries no
        autograd history). An empty input batch yields an empty batch of
        the same shape.
    """
    if self.capacity == 0:
        return elements
    batch = elements.data
    if batch.size(0) == 0:
        # torch.cat rejects an empty list; return an empty batch of the
        # same shape instead of raising.
        return batch.clone()
    choices = []
    for element in batch:
        # Restore the leading batch dimension removed by iteration.
        element = torch.unsqueeze(element, 0)
        if self.size < self.capacity:
            self.size += 1
            # NOTE(review): this stores a view into the caller's batch, which
            # keeps that batch's storage alive and reflects later in-place
            # edits to it; consider storing element.clone() - confirm intent.
            self.elements.append(element)
            choices.append(element)
        elif random.uniform(0, 1) > 0.5:
            index = random.randint(0, self.capacity - 1)
            candidate = self.elements[index].clone()
            self.elements[index] = element
            choices.append(candidate)
        else:
            choices.append(element)
    # torch.cat already returns a plain Tensor; the legacy
    # torch.autograd.Variable wrapper has been a no-op since PyTorch 0.4.
    return torch.cat(choices, 0)
评论列表
文章目录