image_pool.py source code

python

Project: GAN_Liveness_Detection    Author: yunfan0621
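import random

import torch
from torch.autograd import Variable


# Method of the image pool class. It relies on self.pool_size, self.num_imgs
# and self.images, which are set up elsewhere in the class (not shown here).
def query(self, images):
        # images: torch.Variable of size [batch_size, channel * 2, w, h]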

        if self.pool_size == 0:
            return images

        return_images = []
        for image in images.data: # traverse data in batch dimension
            image = torch.unsqueeze(image, 0)

            if self.num_imgs < self.pool_size:
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                return_images.append(image)
            else:
                p = random.uniform(0, 1)
                # randomly substitute
                if p > 0.5:
                    random_id = random.randint(0, self.pool_size-1)
                    tmp = self.images[random_id].clone()
                    self.images[random_id] = image
                    return_images.append(tmp)
                else:
                    return_images.append(image)

        return_images = Variable(torch.cat(return_images, 0))

        return return_images
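
The method above keeps a bounded history of images. While the pool is filling, each image is stored and returned unchanged. Once the pool is full, each new image is, with probability 0.5, swapped with a randomly chosen stored image (the stored one is returned and the new one takes its place); otherwise it is returned directly. The following is a minimal, torch-free sketch of that policy on plain Python values. The names `HistoryPool`, `items` and `rng` are hypothetical and chosen for illustration only; the original class definition and constructor are not shown in the source, and this sketch tracks the fill level with `len(self.items)` rather than a separate `num_imgs` counter.

```python
import random


class HistoryPool:
    """Torch-free sketch of the buffer policy used by query() above.

    HistoryPool and its constructor are hypothetical; they are not taken
    from the original project.
    """

    def __init__(self, pool_size, rng=None):
        self.pool_size = pool_size
        self.items = []  # stored history, at most pool_size entries
        # An injectable random source makes the behaviour reproducible.
        self.rng = rng if rng is not None else random.Random()

    def query(self, batch):
        # A pool of size 0 disables buffering: the batch passes through.
        if self.pool_size == 0:
            return list(batch)

        out = []
        for item in batch:
            if len(self.items) < self.pool_size:
                # Fill phase: store the item and also return it unchanged.
                self.items.append(item)
                out.append(item)
            elif self.rng.uniform(0, 1) > 0.5:
                # Swap: return a stored item and keep the new one instead.
                idx = self.rng.randint(0, self.pool_size - 1)
                out.append(self.items[idx])
                self.items[idx] = item
            else:
                # No swap: return the new item, leave the pool untouched.
                out.append(item)
        return out


pool = HistoryPool(pool_size=3, rng=random.Random(0))
first = pool.query([1, 2, 3])   # pool is still filling, batch comes back as-is
second = pool.query([4, 5, 6])  # each entry is either new or drawn from history
```

Whichever branch is taken, every item ends up either in the returned batch or in the pool, so no item is lost or duplicated; only the order in which items reach the caller changes.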