def query(self, images):
    """Return a batch mixing the incoming images with previously stored ones.

    While fewer than ``self.pool_size`` images have been stored, every
    incoming image is appended to ``self.images`` and returned unchanged.
    Once the pool is full, each incoming image is handled by a random draw:
    when the draw exceeds 0.5, a randomly chosen stored image is returned and
    the incoming image takes its slot; otherwise the incoming image is
    returned and the pool is left untouched.

    Args:
        images: Batch that is iterated along its first dimension; each item
            is treated as one image.

    Returns:
        A ``Variable`` wrapping the selected images concatenated along
        dimension 0. When ``self.pool_size`` is 0 or the batch yields no
        images, the input itself is wrapped and returned.
    """
    if self.pool_size == 0:
        # Pooling disabled: pass the batch straight through.
        return Variable(images)

    return_images = []
    for image in images:
        # Iterating drops the leading dimension; put it back so the
        # per-image tensors can be concatenated along dim 0 below.
        image = torch.unsqueeze(image, 0)
        if self.num_imgs < self.pool_size:
            # Pool is still filling up: store the image and return it as-is.
            self.num_imgs = self.num_imgs + 1
            self.images.append(image)
            return_images.append(image)
        elif random.uniform(0, 1) > 0.5:
            # random.randint is inclusive on both ends, hence pool_size - 1.
            random_id = random.randint(0, self.pool_size - 1)
            # Clone before overwriting the slot so the returned tensor is
            # a separate copy of the previously stored image.
            tmp = self.images[random_id].clone()
            self.images[random_id] = image
            return_images.append(tmp)
        else:
            return_images.append(image)

    if not return_images:
        # torch.cat rejects an empty list; an empty batch is returned as-is.
        return Variable(images)
    return Variable(torch.cat(return_images, 0))
image_pool.py 文件源码
python
阅读 23
收藏 0
点赞 0
评论 0
评论列表
文章目录