def main(_):
    """Configure the run from FLAGS, build a Discriminator and dispatch on FLAGS.mode.

    Creates FLAGS.checkpoint_dir and FLAGS.sample_dir if they are missing,
    opens a TensorFlow session capped at 80% of GPU memory, overrides the
    output size / crop / dataset directory flags according to FLAGS.dataset,
    and then runs ``test`` or ``train`` on the Discriminator. The
    'complete' mode only prints a message.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If FLAGS.mode is 'test' or 'train' but the selected
            dataset does not define a synthetic-image directory (the
            'inria' branch and unrecognized datasets do not set one).
    """
    # Make sure the output directories exist before anything writes to them.
    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    # Do not take all memory
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.80)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        # Bound up front so that datasets without a synthetic directory do
        # not leave the name undefined when it is used further down.
        syn_dir = None

        if FLAGS.dataset == 'cityscapes':
            print('Select CITYSCAPES')
            mask_dir = CITYSCAPES_mask_dir  # not read again in this function
            syn_dir = CITYSCAPES_syn_dir
            FLAGS.output_size_h, FLAGS.output_size_w, FLAGS.is_crop = 192, 512, False
            FLAGS.dataset_dir = CITYSCAPES_dir
        elif FLAGS.dataset == 'inria':
            print('Select INRIAPerson')
            FLAGS.output_size_h, FLAGS.output_size_w, FLAGS.is_crop = 160, 96, False
            FLAGS.dataset_dir = INRIA_dir
        elif FLAGS.dataset == 'indoor':
            print('Select indoor')
            # The indoor dataset reuses the CITYSCAPES synthetic directory.
            syn_dir = CITYSCAPES_syn_dir
            FLAGS.output_size_h, FLAGS.output_size_w, FLAGS.is_crop = 256, 256, False
            FLAGS.dataset_dir = indoor_dir
        # Any other dataset name falls through and keeps the existing FLAGS values.

        discriminator = Discriminator(
            sess,
            batch_size=FLAGS.batch_size,
            output_size_h=FLAGS.output_size_h,
            output_size_w=FLAGS.output_size_w,
            c_dim=FLAGS.c_dim,
            dataset_name=FLAGS.dataset,
            checkpoint_dir=FLAGS.checkpoint_dir,
            dataset_dir=FLAGS.dataset_dir,
        )

        # Fail with an actionable message instead of an UnboundLocalError
        # when the chosen dataset never configured a synthetic directory.
        if FLAGS.mode in ('test', 'train') and syn_dir is None:
            raise ValueError(
                'dataset %r does not define a synthetic-image directory, '
                'which mode %r requires' % (FLAGS.dataset, FLAGS.mode)
            )

        if FLAGS.mode == 'test':
            print('Testing!')
            discriminator.test(FLAGS, syn_dir)
        elif FLAGS.mode == 'train':
            print('Train!')
            discriminator.train(FLAGS, syn_dir)
        elif FLAGS.mode == 'complete':
            print('Complete!')
评论列表
文章目录