Gan.py source code

python

Project: ICGan-tensorflow  Author: zhangqianhui
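def test(self):
        """Restore both encoder checkpoints, run one batch, then save and display the generated and real images."""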

        init = tf.global_variables_initializer()

        with tf.Session() as sess:

            sess.run(init)

            self.saver_z.restore(sess, self.encode_z_model)
            self.saver_y.restore(sess, self.encode_y_model)

            realbatch_array, _ = MnistData.getNextBatch(self.ds_train, self.label_y, 0, 50,
                                                        self.batch_size)

            output_image , label_y = sess.run([self.fake_images,self.e_y], feed_dict={self.images: realbatch_array})

            # one-hot label output; to convert it to class indices
            # (tf.arg_max is deprecated in favor of tf.argmax):
            # label_y = tf.argmax(label_y, 1)

            print(label_y)

            save_images(output_image , [8 , 8] , './{}/test{:02d}_{:04d}.png'.format(self.sample_path , 0, 0))
            save_images(realbatch_array , [8 , 8] , './{}/test{:02d}_{:04d}_r.png'.format(self.sample_path , 0, 0))

            gen_img = cv2.imread('./{}/test{:02d}_{:04d}.png'.format(self.sample_path , 0, 0), 0)
            real_img = cv2.imread('./{}/test{:02d}_{:04d}_r.png'.format(self.sample_path , 0, 0), 0)


            cv2.imshow("test_EGan", gen_img)
            cv2.imshow("Real_Image", real_img)

            cv2.waitKey(-1)

            print("Test finish!")
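
The commented-out conversion in test() can also be applied to the fetched result directly, since sess.run returns plain arrays rather than tensors. A minimal sketch, assuming label_y is a 2-D array of per-class scores with one row per image; one_hot_to_indices and sample_label_y are illustrative names, not part of the project:

```python
def one_hot_to_indices(rows):
    """Return the index of the largest entry in each row (a row-wise argmax)."""
    return [max(range(len(row)), key=lambda i: row[i]) for row in rows]


# Hypothetical stand-in for the label_y array printed by test().
sample_label_y = [
    [0.0, 0.9, 0.1],
    [0.7, 0.2, 0.1],
    [0.1, 0.1, 0.8],
]

predicted_classes = one_hot_to_indices(sample_label_y)
print(predicted_classes)  # [1, 0, 2]
```

With NumPy available, np.argmax(label_y, axis=1) gives the same indices in a single call.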