mainjamaica.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:WaterGAN 作者: kskin 项目源码 文件源码
def main(_):
  pp.pprint(flags.FLAGS.__flags)

  if FLAGS.input_width is None:
    FLAGS.input_width = FLAGS.input_height
  if FLAGS.output_width is None:
    FLAGS.output_width = FLAGS.output_height

  if not os.path.exists(FLAGS.checkpoint_dir):
    os.makedirs(FLAGS.checkpoint_dir)
  if not os.path.exists(FLAGS.sample_dir):
    os.makedirs(FLAGS.sample_dir)

  run_config = tf.ConfigProto()
  run_config.gpu_options.allow_growth=True
  with tf.Session(config=run_config) as sess:
    wgan = WGAN(
      sess,
      input_width=FLAGS.input_width,
      input_height=FLAGS.input_height,
      input_water_width=FLAGS.input_water_width,
      input_water_height=FLAGS.input_water_height,
      output_width=FLAGS.output_width,
      output_height=FLAGS.output_height,
      batch_size=FLAGS.batch_size,
      c_dim=FLAGS.c_dim,
      max_depth = FLAGS.max_depth,
      save_epoch=FLAGS.save_epoch,
      water_dataset_name=FLAGS.water_dataset,
      air_dataset_name = FLAGS.air_dataset,
      depth_dataset_name = FLAGS.depth_dataset,
      input_fname_pattern=FLAGS.input_fname_pattern,
      is_crop=FLAGS.is_crop,
      checkpoint_dir=FLAGS.checkpoint_dir,
      results_dir = FLAGS.results_dir,
      sample_dir=FLAGS.sample_dir,
      num_samples = FLAGS.num_samples)

    if FLAGS.is_train:
      wgan.train(FLAGS)
    else:
      if not wgan.load(FLAGS.checkpoint_dir):
        raise Exception("[!] Train a model first, then run test mode")
      wgan.test(FLAGS)

    # to_json("./web/js/layers.js", [wgan.h0_w, wgan.h0_b, wgan.g_bn0],
    #                 [wgan.h1_w, wgan.h1_b, wgan.g_bn1],
    #                 [wgan.h2_w, wgan.h2_b, wgan.g_bn2],
    #                 [wgan.h3_w, wgan.h3_b, wgan.g_bn3],
    #                 [wgan.h4_w, wgan.h4_b, None])

    # Below is codes for visualization
    #OPTION = 1
    #visualize(sess, wgan, FLAGS, OPTION)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号