def rec_test(test_data, n_epochs=0, batch_size=128, output_dir=None):
    """Compute the mean reconstruction cost on test data and optionally save a preview.

    The data is consumed in full mini-batches of ``batch_size`` items; a
    trailing partial batch (``len(test_data) % batch_size`` items) is skipped.

    Args:
        test_data: array-like sliced along axis 0 with ``[a:b, ...]``; each
            slice is passed to ``_train_p_cost`` (presumably already in the
            model's input transform -- TODO confirm against the caller).
        n_epochs: epoch number; used only in the preview header text and the
            PNG file name.
        batch_size: number of items per ``_train_p_cost`` call.
        output_dir: when not None, the first batch's inputs (top row) and
            reconstructions (bottom row) are stacked into one image, passed to
            ``html.save_image`` / ``html.save()``, and written to
            ``<output_dir>/rec_epoch_XXXXX.png``.

    Returns:
        ``np.mean`` of the per-batch costs. If there are fewer than
        ``batch_size`` items no batch runs, and this is NaN (NumPy also emits
        a RuntimeWarning).
    """
    print('computing reconstruction loss on test images')
    ntest = len(test_data)
    # Floor division: `/` yields a float in Python 3, which range() rejects.
    n_batches = ntest // batch_size

    costs = []
    imgs = []
    rec_imgs = []
    for n in tqdm(range(n_batches)):
        imb = test_data[n * batch_size:(n + 1) * batch_size, ...]
        # _train_p_cost returns the batch cost and, judging by its use below,
        # the reconstructed batch (gx) -- NOTE(review): confirm.
        [cost, gx] = _train_p_cost(imb)
        costs.append(cost)
        if n == 0:
            # Only the first batch is kept for the visual comparison.
            utils.print_numpy(imb)
            utils.print_numpy(gx)
            imgs.append(train_dcgan_utils.inverse_transform(imb, npx=npx, nc=nc))
            rec_imgs.append(train_dcgan_utils.inverse_transform(gx, npx=npx, nc=nc))

    # Computed unconditionally: it used to be bound only inside the
    # output_dir branch, so output_dir=None raised UnboundLocalError.
    mean_cost = np.mean(costs)

    # `imgs` is empty when no full batch was processed; np.concatenate([])
    # would raise, so there is nothing to preview in that case.
    if output_dir is not None and imgs:
        # Lay the images of the batch side by side (hstack over axis 0), then
        # put inputs above reconstructions.
        save_samples = np.hstack(np.concatenate(imgs, axis=0))
        save_recs = np.hstack(np.concatenate(rec_imgs, axis=0))
        save_comp = np.vstack([save_samples, save_recs])

        txt = 'epoch = %3.3d, cost = %3.3f' % (n_epochs, mean_cost)
        width = save_comp.shape[1]
        # Assumes inverse_transform yields floats in [0, 1] -- TODO confirm.
        save_f = (save_comp * 255).astype(np.uint8)
        html.save_image([save_f], [''], header=txt, width=width, cvt=True)
        html.save()

        # cv2.imwrite expects BGR channel order.
        save_cvt = cv2.cvtColor(save_f, cv2.COLOR_RGB2BGR)
        # Write into the directory the caller asked for (previously the
        # unrelated global `rec_dir` was used here).
        cv2.imwrite(os.path.join(output_dir, 'rec_epoch_%5.5d.png' % n_epochs), save_cvt)

    return mean_cost
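# 评论列表 ("Comment list") -- web-page navigation text captured with this snippet, not code
# 文章目录 ("Table of contents") -- web-page navigation text captured with this snippet, not code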