pix2pix.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:pix2pix-pytorch 作者: 1zb 项目源码 文件源码
def test(epoch):
    """Evaluate the generator on the test data and print mean PSNR / SSIM.

    Runs ``netG`` over every batch yielded by ``testing_data_loader``, converts
    the target and the prediction to channel-last numpy arrays, and accumulates
    per-image PSNR and channel-averaged SSIM. Prints the per-image means.

    Relies on module-level names defined elsewhere in this file: ``args``
    (``args.direction`` selects the translation direction),
    ``testing_data_loader``, the preallocated ``input``/``target`` variables,
    ``netG``, and the ``psnr``/``ssim`` metric functions.

    Args:
        epoch: Current epoch number. Not used by the evaluation itself; kept
            so existing callers keep working.
    """
    avg_psnr = 0
    avg_ssim = 0
    # Count the images actually evaluated instead of relying on len(test_set),
    # so the mean stays correct if the loader skips samples (e.g. drop_last).
    num_images = 0
    for left, right in testing_data_loader:
        # 'lr' translates left -> right; any other value translates right -> left.
        if args.direction == 'lr':
            source, goal = left, right
        else:
            source, goal = right, left
        # Reuse the preallocated buffers: resize to this batch's shape, then copy.
        input.data.resize_(source.size()).copy_(source)
        target.data.resize_(goal.size()).copy_(goal)

        prediction = netG(input)

        # Move axis 1 (channels) to the end so images are indexed [i, h, w, c].
        im_true = np.transpose(target.data.cpu().numpy(), (0, 2, 3, 1))
        im_test = np.transpose(prediction.data.cpu().numpy(), (0, 2, 3, 1))
        num_channels = im_true.shape[-1]
        batch_size = input.size(0)
        for i in range(batch_size):
            avg_psnr += psnr(im_true[i], im_test[i])
            # SSIM per channel, averaged over however many channels the images
            # have, so grayscale and multi-channel outputs also work.
            avg_ssim += sum(
                ssim(im_true[i, :, :, c], im_test[i, :, :, c])
                for c in range(num_channels)
            ) / num_channels
        num_images += batch_size

    # Guard against an empty loader instead of dividing by zero.
    denominator = max(num_images, 1)
    print("[TEST]  PSNR: {:.4f}; SSIM: {:.4f}".format(avg_psnr / denominator, avg_ssim / denominator))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号