def _ensure_dir(path):
    """Create directory ``path`` (including parents) if it does not exist.

    Unlike a separate ``os.path.exists()`` check followed by
    ``os.makedirs()``, this does not fail if another process creates the
    directory in between.

    Raises:
        OSError: if creation fails and no directory exists at ``path``
            afterwards (e.g. permission denied, or a regular file is in
            the way).
    """
    try:
        os.makedirs(path)
    except OSError:
        # makedirs raises when the leaf already exists; that case is fine.
        if not os.path.isdir(path):
            raise


def main(_):
    """Build the DCGAN, restore or initialize it, optionally train, then sample.

    Args:
        _: Unused positional argument (presumably the argv list supplied by
            ``tf.app.run`` -- confirm against the caller).

    Raises:
        IOError: if checkpoints exist but cannot be read, or if no
            checkpoints exist while ``FLAGS.train`` is false (there is
            nothing to sample from without training).
    """
    pp.pprint(FLAGS.__flags)

    # training/inference
    with tf.Session() as sess:
        dcgan = DCGAN(sess, FLAGS)

        # Path checks: make sure every output directory exists up front.
        _ensure_dir(FLAGS.checkpoint_dir)
        model_dir = dcgan.get_model_dir()
        _ensure_dir(os.path.join(FLAGS.log_dir, model_dir))
        _ensure_dir(os.path.join(FLAGS.sample_dir, model_dir))

        # Load checkpoint if found.
        if dcgan.checkpoint_exists():
            print("Loading checkpoints...")
            # A falsy return from load() is treated as a failed restore.
            if not dcgan.load():
                raise IOError("Could not read checkpoints from {0}!".format(
                    FLAGS.checkpoint_dir))
            print("success!")
        else:
            if not FLAGS.train:
                raise IOError("No checkpoints found but need for sampling!")
            print("No checkpoints found. Training from scratch.")
            # NOTE(review): load() is called even though no checkpoint
            # exists, and its return value is ignored here. Presumably it
            # initializes the model's variables in this case -- confirm
            # against DCGAN.load.
            dcgan.load()

        # Train DCGAN.
        if FLAGS.train:
            train(dcgan)

        # Inference / visualization. This runs inside the ``with`` block,
        # i.e. while ``sess`` is still open.
        print("Generating samples...")
        inference.sample_images(dcgan)
        print("Generating visualizations of z...")
        inference.visualize_z(dcgan)
评论列表
文章目录