train_predict_z.py source file

python

Project: iGAN  Author: junyanz
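import os

import cv2
import numpy as np
from tqdm import tqdm

# Names such as utils, train_dcgan_utils, html, _train_p_cost, npx and nc are
# defined at module level in the original script and are not shown in this excerpt.


def rec_test(test_data, n_epochs=0, batch_size=128, output_dir=None):
    """Compute the mean reconstruction cost on test images and save a comparison image."""
    print('computing reconstruction loss on test images')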
    rec_imgs = []
    imgs = []
    costs = []
    ntest = len(test_data)

    # Integer division: range() needs an int in Python 3.
    # A trailing partial batch is skipped.
    for n in tqdm(range(ntest // batch_size)):
        imb = test_data[n*batch_size:(n+1)*batch_size, ...]
        # imb = train_dcgan_utils.transform(xmb, nc=3)
        [cost, gx] = _train_p_cost(imb)
        costs.append(cost)
        if n == 0:
            utils.print_numpy(imb)
            utils.print_numpy(gx)
            imgs.append(train_dcgan_utils.inverse_transform(imb, npx=npx, nc=nc))
            rec_imgs.append(train_dcgan_utils.inverse_transform(gx, npx=npx, nc=nc))

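    # Compute the mean outside the branch so the return value is defined
    # even when output_dir is None.
    mean_cost = np.mean(costs)

    if output_dir is not None:
        # Tile originals (top row) above reconstructions (bottom row).
        save_samples = np.hstack(np.concatenate(imgs, axis=0))
        save_recs = np.hstack(np.concatenate(rec_imgs, axis=0))
        save_comp = np.vstack([save_samples, save_recs])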

        txt = 'epoch = %3.3d, cost = %3.3f' % (n_epochs, mean_cost)

        width = save_comp.shape[1]
        save_f = (save_comp*255).astype(np.uint8)
        html.save_image([save_f], [''], header=txt, width=width, cvt=True)
        html.save()
        save_cvt = cv2.cvtColor(save_f, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(output_dir, 'rec_epoch_%5.5d.png' % n_epochs), save_cvt)

    return mean_cost
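
The tiling step in the function above (originals in the top row, reconstructions in the bottom row) can be tried in isolation. The following is a minimal sketch, not the project's code: `tile_comparison` is a hypothetical helper name, and it assumes the images arrive as float arrays of shape (N, H, W, C) with values in [0, 1], which is what the function appears to expect after `inverse_transform`.

```python
import numpy as np


def tile_comparison(originals, reconstructions):
    """Stack originals (top row) over reconstructions (bottom row).

    Assumption: both inputs are float arrays of shape (N, H, W, C)
    with values in [0, 1].
    """
    # Lay the N images side by side: (N, H, W, C) -> (H, N*W, C).
    top = np.hstack(list(originals))
    bottom = np.hstack(list(reconstructions))
    # Put the two rows on top of each other: (2*H, N*W, C).
    grid = np.vstack([top, bottom])
    # Convert to 8-bit for saving with an image library.
    return (grid * 255).astype(np.uint8)


# Random arrays stand in for real test images and their reconstructions.
rng = np.random.default_rng(0)
originals = rng.random((4, 8, 8, 3))
reconstructions = rng.random((4, 8, 8, 3))
grid = tile_comparison(originals, reconstructions)
print(grid.shape, grid.dtype)
```

With four 8x8 RGB images per row, the result is a single image two rows high and four images wide, ready to be passed to an image writer.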