trainer.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:rarepepes 作者: kendricktan 项目源码 文件源码
def test(self, loader, e):
        """Save a visual sample of input, target and generated output.

        Picks one random example from ``loader.dataset``, runs its input
        through the generator, and writes a 3-row figure
        (``x`` / ``target y`` / ``generated y``) to ``visualize/<e>.png``.

        Args:
            loader: object exposing ``.dataset``; each dataset item is
                indexed as ``item[0]`` (input x) and ``item[1]`` (target y).
                Presumably a torch ``DataLoader`` -- TODO confirm.
            e: value interpolated into the output filename (presumably the
                epoch number -- confirm with caller).

        Side effects:
            Creates ``visualize/`` if missing and writes one PNG into it.
            ``self.dis`` and ``self.gen`` are put in eval mode for the
            forward pass and restored to their previous train/eval mode
            afterwards.
        """
        topilimg = transforms.ToPILImage()

        def _to_pil(var):
            # Drop the size-1 batch dim, bring the data to host memory when
            # running on GPU, rescale via normalize(..., 0, 1) (presumably
            # into [0, 1] -- TODO confirm), then convert to a PIL image.
            tensor = var.squeeze().data
            if self.cuda:
                tensor = tensor.cpu()
            return topilimg(normalize(tensor, 0, 1))

        if not os.path.exists('visualize/'):
            os.makedirs('visualize/')

        # Sample from the dataset itself: len(loader) counts batches, not
        # items, so using it would only ever pick from the first few items.
        idx = random.randint(0, len(loader.dataset) - 1)
        _features = loader.dataset[idx]

        # Add a leading batch dimension of size 1. Each tensor keeps its own
        # spatial size (the old code took orig_y's width from orig_x).
        # self.cudafy presumably moves the tensor to GPU when enabled.
        orig_x = Variable(self.cudafy(_features[0])).unsqueeze(0)
        orig_y = Variable(self.cudafy(_features[1])).unsqueeze(0)

        # Remember the current modes so the caller's training state is not
        # silently left in eval mode after visualization.
        dis_was_training = self.dis.training
        gen_was_training = self.gen.training
        self.dis.eval()
        self.gen.eval()
        try:
            gen_y = self.gen(orig_x)
        finally:
            self.dis.train(dis_was_training)
            self.gen.train(gen_was_training)

        orig_x_img = _to_pil(orig_x)
        orig_y_img = _to_pil(orig_y)
        gen_y_img = _to_pil(gen_y)

        f, (ax1, ax2, ax3) = plt.subplots(
            3, 1, sharey='row'
        )
        try:
            ax1.imshow(orig_x_img)
            ax1.set_title('x')

            ax2.imshow(orig_y_img)
            ax2.set_title('target y')

            ax3.imshow(gen_y_img)
            ax3.set_title('generated y')

            f.savefig('visualize/{}.png'.format(e))
        finally:
            # pyplot keeps every figure alive until closed; without this each
            # call (e.g. once per epoch) would leak a figure.
            plt.close(f)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号