def main(_):
    """Train or test the WGAN model according to the module-level ``FLAGS``.

    The ``_`` argument is unused. It presumably receives argv from
    ``tf.app.run()``; confirm against the caller.

    Steps:
      1. Print the current flag values.
      2. Default ``input_width`` / ``output_width`` to the matching height
         flag when they were not supplied (``None``).
      3. Create ``checkpoint_dir`` and ``sample_dir`` if they do not exist.
      4. Open a TensorFlow session with GPU memory growth enabled and build
         the ``WGAN`` model. Then either train it (``FLAGS.is_train``) or
         restore a checkpoint and run its test routine.

    Raises:
        Exception: in test mode, when ``wgan.load`` returns a falsy value
            for ``FLAGS.checkpoint_dir``.
    """
    # Absl-backed TensorFlow flags (TF >= 1.5) keep Flag *objects* in
    # ``__flags``, so printing that dict shows object reprs, not values.
    # ``flag_values_dict()`` returns the plain values there. Older flag
    # containers lack the method, but their ``__flags`` already holds values.
    flag_values_dict = getattr(flags.FLAGS, "flag_values_dict", None)
    if flag_values_dict is not None:
        pp.pprint(flag_values_dict())
    else:
        pp.pprint(flags.FLAGS.__flags)

    # Square images by default: a missing width falls back to the height.
    if FLAGS.input_width is None:
        FLAGS.input_width = FLAGS.input_height
    if FLAGS.output_width is None:
        FLAGS.output_width = FLAGS.output_height

    # Check-then-create (rather than makedirs(exist_ok=True)) keeps this
    # compatible with Python 2.
    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    run_config = tf.ConfigProto()
    # Allocate GPU memory on demand instead of reserving it all up front.
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        wgan = WGAN(
            sess,
            input_width=FLAGS.input_width,
            input_height=FLAGS.input_height,
            input_water_width=FLAGS.input_water_width,
            input_water_height=FLAGS.input_water_height,
            output_width=FLAGS.output_width,
            output_height=FLAGS.output_height,
            batch_size=FLAGS.batch_size,
            c_dim=FLAGS.c_dim,
            max_depth=FLAGS.max_depth,
            save_epoch=FLAGS.save_epoch,
            water_dataset_name=FLAGS.water_dataset,
            air_dataset_name=FLAGS.air_dataset,
            depth_dataset_name=FLAGS.depth_dataset,
            input_fname_pattern=FLAGS.input_fname_pattern,
            is_crop=FLAGS.is_crop,
            checkpoint_dir=FLAGS.checkpoint_dir,
            results_dir=FLAGS.results_dir,
            sample_dir=FLAGS.sample_dir,
            num_samples=FLAGS.num_samples)

        if FLAGS.is_train:
            wgan.train(FLAGS)
        else:
            # NOTE(review): this assumes wgan.load returns a plain boolean.
            # If it returns a tuple (e.g. (success, counter)), the value is
            # always truthy and this guard never fires. Verify against WGAN.
            if not wgan.load(FLAGS.checkpoint_dir):
                raise Exception("[!] Train a model first, then run test mode")
            wgan.test(FLAGS)

        # Optional exports, disabled:
        # to_json("./web/js/layers.js", [wgan.h0_w, wgan.h0_b, wgan.g_bn0],
        #         [wgan.h1_w, wgan.h1_b, wgan.g_bn1],
        #         [wgan.h2_w, wgan.h2_b, wgan.g_bn2],
        #         [wgan.h3_w, wgan.h3_b, wgan.g_bn3],
        #         [wgan.h4_w, wgan.h4_b, None])
        # Visualization:
        # OPTION = 1
        # visualize(sess, wgan, FLAGS, OPTION)
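# 评论列表 ("Comment list"): leftover navigation text from the web page this code was copied from.
# 文章目录 ("Table of contents"): leftover navigation text from the web page this code was copied from.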