def test(self):
    """Restore the trained encoders, run one batch through the model, and show it.

    Steps, in order:
      1. Initialise all TF variables, then restore ``self.saver_z`` and
         ``self.saver_y`` from ``self.encode_z_model`` / ``self.encode_y_model``.
      2. Fetch one batch of real images via ``MnistData.getNextBatch``.
      3. Evaluate ``self.fake_images`` and ``self.e_y`` with that batch fed to
         ``self.images``, and print the ``self.e_y`` result.
      4. Save the generated and the real batch as image grids under
         ``./<self.sample_path>/``, read them back as grayscale, and display
         them in two OpenCV windows until a key is pressed.
    """
    # Output files for the generated grid and for the real-input grid.
    gen_path = './{}/test{:02d}_{:04d}.png'.format(self.sample_path, 0, 0)
    real_path = './{}/test{:02d}_{:04d}_r.png'.format(self.sample_path, 0, 0)

    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        # Initialise everything first so that variables not covered by the
        # savers have values, then overwrite the encoder variables from the
        # checkpoints.
        sess.run(init)
        self.saver_z.restore(sess, self.encode_z_model)
        self.saver_y.restore(sess, self.encode_y_model)

        # NOTE(review): the meaning of the literal 0 and 50 arguments is
        # defined by MnistData.getNextBatch, which is not visible here -
        # confirm against that helper.
        realbatch_array, _ = MnistData.getNextBatch(
            self.ds_train, self.label_y, 0, 50, self.batch_size)
        output_image, label_y = sess.run(
            [self.fake_images, self.e_y],
            feed_dict={self.images: realbatch_array})

        # Parenthesised form: identical output on Python 2, valid on Python 3.
        print(label_y)

        # NOTE(review): an [8, 8] grid presumably assumes batch_size == 64;
        # confirm with save_images / the configured batch size.
        save_images(output_image, [8, 8], gen_path)
        save_images(realbatch_array, [8, 8], real_path)

        # Read the saved grids back in grayscale (imread flag 0) for display.
        gen_img = cv2.imread(gen_path, 0)
        real_img = cv2.imread(real_path, 0)
        cv2.imshow("test_EGan", gen_img)
        cv2.imshow("Real_Image", real_img)
        cv2.waitKey(-1)  # block until any key is pressed
        cv2.destroyAllWindows()  # release the windows opened above

        print("Test finish!")
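# 评论列表 ("Comment list") and 文章目录 ("Table of contents"): navigation labels
# left over from the web page this code was copied from; not part of the program.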