main.py source code

python

Project: DeepLearning    Author: Wanwannodao
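import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.global_variables_initializer)

# Loader, LSGAN, utils and FLAGS are defined elsewhere in the project (not shown here).
def main(_):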
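    # Data loader that serves image batches of FLAGS.batch_size.
    loader = Loader(FLAGS.data_dir, FLAGS.data, FLAGS.batch_size)
    print("# of data: {}".format(loader.data_num))
    with tf.Session() as sess:
        # Build the LSGAN graph for batches of 112x112 RGB images, then initialize all variables.
        lsgan = LSGAN([FLAGS.batch_size, 112, 112, 3])
        sess.run(tf.global_variables_initializer())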

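        for epoch in range(10000):
            loader.reset()

            # Each step consumes FLAGS.d real batches for D, then makes one G update.
            for step in range(loader.batch_num // FLAGS.d):
                # Visualize generator samples at the start of every 100th epoch.
                if step == 0 and epoch % 100 == 0:
                    utils.visualize(sess.run(lsgan.gen_img), epoch)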

                # Discriminator updates: FLAGS.d real batches per generator step.
                for _ in range(FLAGS.d):
                    batch = np.asarray(loader.next_batch(), dtype=np.float32)
                    # Scale pixel values from [0, 255] to [-1, 1].
                    batch = (batch - 127.5) / 127.5
                    feed = {lsgan.X: batch}
                    sess.run(lsgan.d_train_op, feed_dict=feed)

                # One generator update; no real images are fed.
                sess.run(lsgan.g_train_op)
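The training loop relies on two details that are easy to check in isolation: pixels are mapped from [0, 255] to [-1, 1] with `(x - 127.5) / 127.5`, and the outer loop runs `batch_num` divided by `d` (rounded down) times, doing `d` discriminator updates per generator update. The NumPy sketch below mirrors both. The helper names `normalize`, `denormalize` and `update_schedule` are hypothetical and not part of the project. `denormalize` is the inverse mapping you would typically apply before saving generator output; whether the project's `utils.visualize` does this internally is not shown.

```python
import numpy as np


def normalize(batch):
    """Map pixel values in [0, 255] to [-1, 1], as the training loop does."""
    return (np.asarray(batch, dtype=np.float32) - 127.5) / 127.5


def denormalize(images):
    """Hypothetical inverse: map [-1, 1] back to uint8 pixels in [0, 255]."""
    return np.clip(np.rint(images * 127.5 + 127.5), 0, 255).astype(np.uint8)


def update_schedule(batch_num, d):
    """List the per-epoch update order: d discriminator ("D") steps per generator ("G") step.

    Batches left over when batch_num is not divisible by d are not used that epoch.
    """
    schedule = []
    for _ in range(batch_num // d):
        schedule.extend(["D"] * d)
        schedule.append("G")
    return schedule
```

For example, with 7 batches and `d = 2`, `update_schedule(7, 2)` gives three D, D, G rounds, and the seventh batch is skipped in that epoch.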
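The `LSGAN` class itself is not shown, so its exact losses are unknown. As an assumption, here is the least-squares objective with the common 0-1 target coding from Mao et al.'s LSGAN paper: the discriminator pushes D(x) toward 1 and D(G(z)) toward 0, while the generator pushes D(G(z)) toward 1. The functions `lsgan_d_loss` and `lsgan_g_loss` are hypothetical names that operate on discriminator outputs. They also show why `g_train_op` can run without a `feed_dict`: the generator loss depends only on D(G(z)), and the graph presumably samples the noise z internally.

```python
import numpy as np


def lsgan_d_loss(d_real, d_fake, a=0.0, b=1.0):
    """Least-squares discriminator loss: push D(x) toward b and D(G(z)) toward a."""
    return 0.5 * np.mean((d_real - b) ** 2) + 0.5 * np.mean((d_fake - a) ** 2)


def lsgan_g_loss(d_fake, c=1.0):
    """Least-squares generator loss: push D(G(z)) toward c (no real images needed)."""
    return 0.5 * np.mean((d_fake - c) ** 2)
```

A discriminator that outputs exactly 1 on real images and 0 on fakes has zero loss. A generator whose samples all score 1 has zero loss.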