train_imagenet.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:streetview 作者: ydnaandy123 项目源码 文件源码
def main(_):
    """Build the multi-GPU DCGAN from command-line flags, then train or load it and visualize.

    Prints the parsed flags, makes sure the checkpoint and sample directories
    exist, builds a ``DCGAN`` inside a TensorFlow session, and then either
    trains it (``FLAGS.is_train``) or restores it from ``FLAGS.checkpoint_dir``.
    In both cases it finishes by calling ``visualize``.

    Args:
        _: Unused positional argument (presumably the argv list forwarded by
            the flags/app runner -- verify against the caller).

    Raises:
        ValueError: If ``FLAGS.dataset`` is ``'mnist'``, which this script
            does not support.
        OSError: If a required directory cannot be created, or the path
            exists but is not a directory.
    """
    import errno

    pp.pprint(flags.FLAGS.__flags)

    # Reject unsupported input before creating directories or a session.
    # An explicit raise (instead of `assert`) survives `python -O`.
    if FLAGS.dataset == 'mnist':
        raise ValueError("The 'mnist' dataset is not supported by this script.")

    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        # EAFP rather than exists()+makedirs(): no race if the directory is
        # created between the check and the call. `exist_ok` is not available
        # on Python 2, so tolerate EEXIST manually, but only for a real directory.
        try:
            os.makedirs(directory)
        except OSError as err:
            if err.errno != errno.EEXIST or not os.path.isdir(directory):
                raise

    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)

    with tf.Session(config=session_config) as sess:
        dcgan = DCGAN(
            sess,
            image_size=FLAGS.image_size,
            batch_size=FLAGS.batch_size,
            sample_size=64,
            z_dim=8192,
            d_label_smooth=0.25,
            generator_target_prob=0.75 / 2.0,
            out_stddev=0.075,
            out_init_b=-0.45,
            # NOTE(review): height and width both come from FLAGS.image_width,
            # i.e. square images are assumed -- confirm this is intended.
            image_shape=[FLAGS.image_width, FLAGS.image_width, 3],
            dataset_name=FLAGS.dataset,
            is_crop=FLAGS.is_crop,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir,
            generator=Generator(),
            train_func=train,
            discriminator_func=discriminator,
            build_model_func=build_model,
            config=FLAGS,
            devices=["gpu:0", "gpu:1", "gpu:2", "gpu:3"],
        )

        if FLAGS.is_train:
            # Single-argument print() behaves the same on Python 2 and 3.
            print("TRAINING")
            dcgan.train(FLAGS)
            print("DONE TRAINING")
        else:
            # NOTE(review): the return value of load() is ignored; if it
            # reports failure, visualization runs on an untrained model.
            dcgan.load(FLAGS.checkpoint_dir)

        # Mode selector passed to visualize(); its meaning is defined there.
        visualization_option = 2
        visualize(sess, dcgan, FLAGS, visualization_option)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号