def test(self, loader, e):
    """Save a visualization of one random sample's translation for ``e``.

    Puts ``self.dis`` and ``self.gen`` in eval mode and picks a random item
    from ``loader.dataset``. It runs the generator on the item's input and
    writes a three-row figure (input ``x``, target ``y``, generated ``y``)
    to ``visualize/<e>.png``. The ``visualize/`` directory is created if
    missing.

    Args:
        loader: data loader whose ``dataset[i]`` returns a sequence whose
            first two entries are the input and target tensors
            (presumably C x H x W images -- TODO confirm).
        e: identifier used as the output file name (likely the epoch).

    Note:
        Both networks are left in eval mode. The caller must switch them
        back to training mode if training continues.
    """
    self.dis.eval()
    self.gen.eval()
    topilimg = transforms.ToPILImage()
    if not os.path.exists('visualize/'):
        os.makedirs('visualize/')

    # Draw over the dataset items. len(loader) is the number of *batches*,
    # which would restrict the draw to the first few items only.
    idx = random.randint(0, len(loader.dataset) - 1)
    _features = loader.dataset[idx]

    # Add a leading batch dimension of size 1. unsqueeze keeps each tensor's
    # own sizes; the previous view() call reused x's width for y.
    orig_x = Variable(self.cudafy(_features[0])).unsqueeze(0)
    orig_y = Variable(self.cudafy(_features[1])).unsqueeze(0)
    gen_y = self.gen(orig_x)

    def _to_pil(t):
        # Drop singleton dims, move to host memory when running on CUDA,
        # rescale to [0, 1], and convert to a PIL image for imshow.
        if self.cuda:
            data = t.squeeze().cpu().data
        else:
            data = t.squeeze().data
        return topilimg(normalize(data, 0, 1))

    panels = (
        (_to_pil(orig_x), 'x'),
        (_to_pil(orig_y), 'target y'),
        (_to_pil(gen_y), 'generated y'),
    )

    f, axes = plt.subplots(
        3, 1, sharey='row'
    )
    for ax, (img, title) in zip(axes, panels):
        ax.imshow(img)
        ax.set_title(title)
    f.savefig('visualize/{}.png'.format(e))
    # pyplot keeps every figure it creates alive until it is closed. Without
    # this, repeated calls accumulate open figures and their memory.
    plt.close(f)
评论列表
文章目录