main.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:WassersteinGAN-TensorFlow 作者: MustafaMustafa 项目源码 文件源码
def main(_):
    pp.pprint(flags.FLAGS.__flags)

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    with tf.Session() as sess:
        dcgan = DCGAN(sess, 
                      dataset=FLAGS.dataset,
                      batch_size=FLAGS.batch_size,
                      output_size=FLAGS.output_size,
                      c_dim=FLAGS.c_dim,
                      z_dim=FLAGS.z_dim)

        if FLAGS.is_train:
            if FLAGS.preload_data == True:
                data = get_data_arr(FLAGS)
            else:
                data = glob(os.path.join('./data', FLAGS.dataset, '*.jpg'))
            train.train_wasserstein(sess, dcgan, data, FLAGS)
        else:
            dcgan.load(FLAGS.checkpoint_dir)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号