main.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:dcgan-tfslim 作者: mqtlam 项目源码 文件源码
def main(_):
    """Prepare output directories, restore or train a DCGAN, then sample from it.

    Configuration is read from the module-level ``FLAGS`` object. The single
    positional argument is ignored; presumably it receives leftover argv from
    the application runner -- TODO confirm against the caller.

    Raises:
        IOError: if a checkpoint exists but ``dcgan.load()`` reports failure,
            or if no checkpoint exists and ``FLAGS.train`` is false (sampling
            requires previously trained weights).
    """
    pp.pprint(FLAGS.__flags)

    # training/inference
    with tf.Session() as sess:
        dcgan = DCGAN(sess, FLAGS)

        # Ensure every output directory exists before anything writes to it.
        # Log and sample outputs are namespaced by the model's directory name.
        model_dir = dcgan.get_model_dir()
        required_dirs = (
            FLAGS.checkpoint_dir,
            os.path.join(FLAGS.log_dir, model_dir),
            os.path.join(FLAGS.sample_dir, model_dir),
        )
        for path in required_dirs:
            if not os.path.exists(path):
                os.makedirs(path)

        # Load a checkpoint if one is found. Otherwise, only training may proceed.
        # Single-argument print(...) calls behave identically on Python 2 and
        # Python 3, unlike the bare print statement.
        if dcgan.checkpoint_exists():
            print("Loading checkpoints...")
            if not dcgan.load():
                raise IOError("Could not read checkpoints from {0}!".format(
                    FLAGS.checkpoint_dir))
            print("success!")
        else:
            if not FLAGS.train:
                raise IOError("No checkpoints found but need for sampling!")
            print("No checkpoints found. Training from scratch.")
            # NOTE(review): load() is called here even though no checkpoint
            # exists, and its return value is ignored. Presumably it performs
            # variable initialization in this case -- confirm against DCGAN.load.
            dcgan.load()

        # train DCGAN
        if FLAGS.train:
            train(dcgan)

        # inference/visualization code goes here
        print("Generating samples...")
        inference.sample_images(dcgan)
        print("Generating visualizations of z...")
        inference.visualize_z(dcgan)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号