def _save_tiled_figure(title, tile_img, filename):
    """Display *tile_img* in a new grayscale, axis-free figure and save it.

    Args:
        title: matplotlib figure label passed to ``plt.figure``.
        tile_img: image array to display (e.g. from ``batchmat_to_tileimg``).
        filename: file name written inside ``FLAGS.save_dir``.

    Returns:
        The created matplotlib figure.
    """
    import os

    fig = plt.figure(title)
    plt.gray()
    plt.axis('off')
    plt.imshow(tile_img)
    fig.savefig(os.path.join(FLAGS.save_dir, filename))
    return fig


def test():
    """Restore the saved model and write qualitative sample images.

    Restores ``<FLAGS.save_dir>/model.ckpt`` into ``sess``, then saves three
    tiled image grids into ``FLAGS.save_dir``:

    * ``original.png``      -- the first 100 rows of ``test_x`` (10x10 grid);
    * ``reconstructed.png`` -- ``p`` evaluated with those rows fed to ``x``;
    * ``generated.png``     -- ``p`` evaluated from one-hot ``w`` codes plus
      standard-normal ``z`` samples (n_fac x n_fac grid).

    Finally calls ``plt.show()``, which blocks on interactive backends until
    the figure windows are closed.
    """
    import os

    saver.restore(sess, os.path.join(FLAGS.save_dir, 'model.ckpt'))
    img_shape = (height, width)

    # NOTE(review): assumes test_x has at least 100 rows so the 10x10 grid
    # is full -- confirm batchmat_to_tileimg's behaviour on short batches.
    batch_x = test_x[0:100]
    _save_tiled_figure(
        'original',
        batchmat_to_tileimg(batch_x, img_shape, (10, 10)),
        'original.png',
    )

    p_recon = sess.run(p, {x: batch_x})
    _save_tiled_figure(
        'reconstructed',
        batchmat_to_tileimg(p_recon, img_shape, (10, 10)),
        'reconstructed.png',
    )

    # Rows k*n_fac .. (k+1)*n_fac - 1 all carry the one-hot code for factor k,
    # each paired with a different z sample; presumably each grid row of the
    # generated image therefore shows one factor (depends on tiling order).
    batch_w = np.repeat(np.eye(n_fac), n_fac, axis=0)
    batch_z = np.random.normal(size=(n_fac * n_fac, n_lat))
    p_gen = sess.run(p, {w: batch_w, z: batch_z})
    _save_tiled_figure(
        'generated',
        batchmat_to_tileimg(p_gen, img_shape, (n_fac, n_fac)),
        'generated.png',
    )

    # TODO: a "factor activation heatmap" (per-label counts of active w units,
    # drawn with sns.heatmap and saved as feature_activation.png) was drafted
    # here but left disabled. The draft looped over range(len(test_x)) rather
    # than the number of batches, reused ``i`` for both nested loops, and read
    # an undefined ``batch_y``; fix those before enabling it.

    plt.show()
# (Scraped web-page navigation labels, not code: "评论列表" = "Comment list", "文章目录" = "Table of contents".)