def main(_):
    """Create per-run output directories, build a DCGAN, then train or sample.

    All outputs are namespaced under ``FLAGS.name``: samples, checkpoints and
    logs go to ``<FLAGS.sample_dir>/<name>``, ``<FLAGS.checkpoint_dir>/<name>``
    and ``<FLAGS.log_dir>/<name>`` respectively. This applies to every
    dataset, not only MNIST.

    Args:
        _: Unused positional argument (presumably the argv list handed over by
            the application runner -- TODO confirm against the caller).
    """
    pp.pprint(flags.FLAGS.__flags)

    # Each run writes into <root>/<FLAGS.name> so that separate experiments do
    # not overwrite each other's samples, checkpoints or logs.
    sample_dir_ = os.path.join(FLAGS.sample_dir, FLAGS.name)
    checkpoint_dir_ = os.path.join(FLAGS.checkpoint_dir, FLAGS.name)
    log_dir_ = os.path.join(FLAGS.log_dir, FLAGS.name)
    for run_dir in (checkpoint_dir_, sample_dir_, log_dir_):
        if not os.path.exists(run_dir):
            os.makedirs(run_dir)

    with tf.Session() as sess:
        # Arguments shared by both dataset branches. Previously the non-MNIST
        # branch ignored the run directories created above, fell back to the
        # raw FLAGS roots, and dropped `config` and `log_dir`. Keeping them in
        # one place keeps the two branches consistent.
        # NOTE(review): non-MNIST runs now read/write checkpoints under
        # <checkpoint_dir>/<name>; older checkpoints stored directly in the
        # root must be moved there to be picked up.
        common_kwargs = dict(
            config=FLAGS,
            batch_size=FLAGS.batch_size,
            dataset_name=FLAGS.dataset,
            is_crop=FLAGS.is_crop,
            checkpoint_dir=checkpoint_dir_,
            sample_dir=sample_dir_,
            log_dir=log_dir_,
        )
        if FLAGS.dataset == 'mnist':
            # MNIST uses a hard-coded output_size=28 and c_dim=1 instead of
            # the corresponding flag values.
            dcgan = DCGAN(sess, output_size=28, c_dim=1, **common_kwargs)
        else:
            dcgan = DCGAN(sess,
                          image_size=FLAGS.image_size,
                          output_size=FLAGS.output_size,
                          c_dim=FLAGS.c_dim,
                          **common_kwargs)

        if FLAGS.is_train:
            dcgan.train(FLAGS)
        else:
            # NOTE(review): there is no explicit dcgan.load() call here;
            # sampling() presumably restores from the checkpoint directory
            # itself -- confirm.
            dcgan.sampling(FLAGS)

        if FLAGS.visualize:
            # Dump generator layer parameters (weights, biases, batch-norm
            # objects) for the web viewer; the last layer has no batch norm.
            to_json("./web/js/layers.js",
                    [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0],
                    [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1],
                    [dcgan.h2_w, dcgan.h2_b, dcgan.g_bn2],
                    [dcgan.h3_w, dcgan.h3_b, dcgan.g_bn3],
                    [dcgan.h4_w, dcgan.h4_b, None])

            # Mode selector forwarded to visualize(); what each value means
            # is defined by visualize() itself.
            OPTION = 2
            visualize(sess, dcgan, FLAGS, OPTION)
评论列表
文章目录