def main(_):
    """Build the model selected by ``FLAGS.dataset`` and run ``FLAGS.mode``.

    Steps, in order:
      1. Pretty-print the parsed flags.
      2. Create ``FLAGS.checkpoint_dir`` and ``FLAGS.sample_dir`` if missing.
      3. Open a TensorFlow session limited to 30% of GPU memory.
      4. Construct a ``DCGAN`` for ``mnist``, otherwise a ``Discriminator``
         (after applying the per-dataset overrides for ``cityscapes`` /
         ``inria``).
      5. Dispatch on ``FLAGS.mode``: ``test``, ``train`` or ``complete``.

    Args:
        _: Ignored. (Presumably the argv list handed over by the app runner;
            confirm against the caller.)

    Raises:
        ValueError: If ``FLAGS.mode`` is ``test`` or ``train`` but no
            ``Discriminator`` was built (``FLAGS.dataset == 'mnist'``).
            Previously this surfaced as an ``UnboundLocalError``.
    """
    pp.pprint(flags.FLAGS.__flags)

    # Make sure the output directories exist before anything writes to them.
    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    # Do not take all memory
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.30)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        # Bind these up front so every dataset/mode combination reaches the
        # dispatch below with defined names instead of an UnboundLocalError.
        discriminator = None
        # Only 'cityscapes' provides a synthetic-data directory.
        # NOTE(review): confirm Discriminator.train/test accept syn_dir=None
        # for the other datasets.
        syn_dir = None

        if FLAGS.dataset == 'mnist':
            # w/ y label
            # NOTE(review): this DCGAN instance is never used by the mode
            # dispatch below; kept because its constructor may build graph
            # state in `sess`.
            dcgan = DCGAN(sess, image_size=FLAGS.image_size,
                          batch_size=FLAGS.batch_size, y_dim=10,
                          output_size=28, c_dim=1,
                          dataset_name=FLAGS.dataset,
                          checkpoint_dir=FLAGS.checkpoint_dir)
        else:
            # w/o y label
            if FLAGS.dataset == 'cityscapes':
                print('Select CITYSCAPES')
                # NOTE(review): mask_dir is not read anywhere in this
                # function as visible here.
                mask_dir = CITYSCAPES_mask_dir
                syn_dir = CITYSCAPES_syn_dir_2
                FLAGS.output_size_h = 192
                FLAGS.output_size_w = 512
                FLAGS.is_crop = False
                FLAGS.dataset_dir = CITYSCAPES_dir
            elif FLAGS.dataset == 'inria':
                print('Select INRIAPerson')
                FLAGS.output_size_h = 160
                FLAGS.output_size_w = 96
                FLAGS.is_crop = False
                FLAGS.dataset_dir = INRIA_dir
            # Any other dataset name falls through with the flag values as
            # given on the command line.
            discriminator = Discriminator(
                sess,
                batch_size=FLAGS.batch_size,
                output_size_h=FLAGS.output_size_h,
                output_size_w=FLAGS.output_size_w,
                c_dim=FLAGS.c_dim,
                dataset_name=FLAGS.dataset,
                checkpoint_dir=FLAGS.checkpoint_dir,
                dataset_dir=FLAGS.dataset_dir)

        if FLAGS.mode in ('test', 'train') and discriminator is None:
            # Fail with an actionable message rather than a NameError.
            raise ValueError(
                "mode %r requires a Discriminator, but dataset %r only "
                "builds a DCGAN" % (FLAGS.mode, FLAGS.dataset))

        if FLAGS.mode == 'test':
            print('Testing!')
            discriminator.test(FLAGS, syn_dir)
        elif FLAGS.mode == 'train':
            print('Train!')
            discriminator.train(FLAGS, syn_dir)
        elif FLAGS.mode == 'complete':
            print('Complete!')
synthesize_discriminator.py 文件源码
python
阅读 35
收藏 0
点赞 0
评论 0
评论列表
文章目录