mnist_svae.py source code

python

Project: tf_practice    Author: juho-lee
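import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# sess, saver, FLAGS, mnist, batch_size, height, width, n_fac, n_lat, the
# graph tensors x, w, z, p and batchmat_to_tileimg are module-level names
# defined elsewhere in mnist_svae.py (not shown here).

def test():
    # Restore the trained weights.
    saver.restore(sess, os.path.join(FLAGS.save_dir, 'model.ckpt'))

    # Plot a 10x10 grid of test digits.
    batch_x, _ = mnist.test.next_batch(batch_size)
    fig = plt.figure('original')
    plt.gray()
    plt.axis('off')
    plt.imshow(batchmat_to_tileimg(batch_x, (height, width), (10, 10)))
    fig.savefig(os.path.join(FLAGS.save_dir, 'original.png'))

    # Reconstruct the same digits by feeding them through the model.
    fig = plt.figure('reconstructed')
    plt.gray()
    plt.axis('off')
    p_recon = sess.run(p, {x: batch_x})
    plt.imshow(batchmat_to_tileimg(p_recon, (height, width), (10, 10)))
    fig.savefig(os.path.join(FLAGS.save_dir, 'reconstructed.png'))

    # Generate samples: row block i (n_fac consecutive samples) switches on
    # only factor i, and every sample draws its own z from N(0, I).
    batch_w = np.zeros((n_fac*n_fac, n_fac))
    for i in range(n_fac):
        batch_w[i*n_fac:(i+1)*n_fac, i] = 1.0
    batch_z = np.random.normal(size=(n_fac*n_fac, n_lat))
    p_gen = sess.run(p, {w: batch_w, z: batch_z})
    I_gen = batchmat_to_tileimg(p_gen, (height, width), (n_fac, n_fac))
    fig = plt.figure('generated')
    plt.gray()
    plt.axis('off')
    plt.imshow(I_gen)
    fig.savefig(os.path.join(FLAGS.save_dir, 'generated.png'))

    # Count how often each factor is active (w > 0) for each digit class
    # over one pass through the test set. Labels are used as row indices,
    # so this assumes mnist was loaded with integer labels (one_hot=False).
    fig = plt.figure('factor activation heatmap')
    hist = np.zeros((10, n_fac))
    for _ in range(mnist.test.num_examples // batch_size):
        batch_x, batch_y = mnist.test.next_batch(batch_size)
        batch_w = sess.run(w, {x: batch_x})
        for j in range(batch_size):
            hist[batch_y[j], batch_w[j] > 0] += 1
    sns.heatmap(hist)
    fig.savefig(os.path.join(FLAGS.save_dir, 'feature_activation.png'))

    plt.show()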
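`batchmat_to_tileimg` is called in `test()` but not defined in this file. Judging only from its call sites, it takes a batch of flattened `height*width` images plus a `(rows, cols)` grid and returns a single 2-D mosaic that `plt.imshow` can display. The sketch below is a hypothetical stand-in named `tile_batch`; its zero-padding and row-major layout are assumptions, not the repository's actual helper.

```python
import numpy as np


def tile_batch(batch, img_shape, grid_shape):
    """Arrange a (N, H*W) batch of flattened images into one 2-D mosaic.

    Hypothetical stand-in for batchmat_to_tileimg; the real helper may
    differ (e.g. in padding, spacing, or row/column order).
    """
    h, w = img_shape
    rows, cols = grid_shape
    n = rows * cols
    # Keep at most rows*cols images and unflatten each one to (h, w).
    imgs = np.asarray(batch)[:n].reshape(-1, h, w)
    # Assumption: pad with blank tiles if the batch is too small.
    if imgs.shape[0] < n:
        pad = np.zeros((n - imgs.shape[0], h, w), dtype=imgs.dtype)
        imgs = np.concatenate([imgs, pad], axis=0)
    # (rows, cols, h, w) -> (rows, h, cols, w) -> (rows*h, cols*w):
    # image k lands at grid cell (k // cols, k % cols).
    grid = imgs.reshape(rows, cols, h, w).transpose(0, 2, 1, 3)
    return grid.reshape(rows * h, cols * w)


if __name__ == "__main__":
    batch = np.arange(24, dtype=float).reshape(4, 6)  # four 2x3 "images"
    print(tile_batch(batch, (2, 3), (2, 2)))
```

The reshape/transpose pair avoids a Python loop over tiles; with a 100-image MNIST batch, `tile_batch(batch_x, (28, 28), (10, 10))` would give a 280x280 array.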
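The loop that fills `batch_w` for the "generated" figure gives row block `i` the one-hot vector for factor `i`. A vectorized equivalent, sketched here as a hypothetical `factor_grid` function, repeats each row of an identity matrix `n_fac` times. Assuming the tiling helper fills the grid row by row, each row of the generated image then shows one factor under different random `z` draws. Unlike the original, the sketch takes a NumPy `Generator` so the `z` draws can be made reproducible.

```python
import numpy as np


def factor_grid(n_fac, n_lat, rng=None):
    """Build (w, z) feeds: n_fac samples per factor, one factor active each.

    Hypothetical helper; equivalent to the batch_w loop in test().
    """
    rng = np.random.default_rng() if rng is None else rng
    # Row i of eye(n_fac) is repeated n_fac times, so rows
    # i*n_fac .. (i+1)*n_fac - 1 are all the one-hot vector for factor i.
    batch_w = np.repeat(np.eye(n_fac), n_fac, axis=0)
    # Independent standard-normal latent code for every sample.
    batch_z = rng.normal(size=(n_fac * n_fac, n_lat))
    return batch_w, batch_z


if __name__ == "__main__":
    w, z = factor_grid(3, 2, np.random.default_rng(0))
    print(w)
    print(z.shape)
```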
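The heatmap counts, for each digit class, how many test images activate each factor (`w > 0`). The per-sample inner loop can be replaced by one vectorized update per batch; `factor_label_counts` is a hypothetical name, and like the original loop it assumes integer class labels rather than one-hot vectors.

```python
import numpy as np


def factor_label_counts(labels, w, n_classes=10):
    """Count, per class, how often each factor is active (w > 0).

    labels: integer class ids, shape (N,)  (assumes one_hot=False)
    w:      factor activations, shape (N, n_fac)
    """
    labels = np.asarray(labels)
    active = (np.asarray(w) > 0).astype(float)
    hist = np.zeros((n_classes, active.shape[1]))
    # Unbuffered add: hist[labels[k]] += active[k] for every k, so
    # repeated labels within a batch are all counted.
    np.add.at(hist, labels, active)
    return hist


if __name__ == "__main__":
    labels = np.array([0, 0, 2])
    w = np.array([[1.0, 0.0], [0.5, -1.0], [0.0, 3.0]])
    print(factor_label_counts(labels, w, n_classes=3))
```

`np.add.at` is used instead of `hist[labels] += active` because buffered fancy-index assignment keeps only one contribution per repeated label, which would undercount whenever two images in a batch share a digit.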