def main(_):
    """Prepare output directories, build the DCGAN, then train it or restore it.

    All configuration comes from the module-level ``FLAGS``.

    * ``FLAGS.checkpoint_dir`` and ``FLAGS.sample_dir`` are created if missing.
    * A ``DCGAN`` is constructed inside a ``tf.Session`` using the dataset,
      batch size, output size, ``c_dim`` and ``z_dim`` flags.
    * If ``FLAGS.is_train`` is set, the training data is gathered and passed to
      ``train.train_wasserstein``. Otherwise ``dcgan.load`` is called with
      ``FLAGS.checkpoint_dir``.

    Args:
        _: Ignored positional argument (presumably the argv list supplied by
            ``tf.app.run`` -- confirm against the caller).
    """
    # NOTE(review): on newer TensorFlow/absl versions ``__flags`` may hold Flag
    # objects rather than plain values; ``flag_values_dict()`` could be a
    # cleaner thing to print -- confirm against the TF version in use.
    pp.pprint(flags.FLAGS.__flags)

    # Ensure both output locations exist before anything tries to write there.
    # Kept as an exists-check + makedirs (rather than exist_ok=True) so the
    # code stays compatible with interpreters that lack that keyword.
    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    with tf.Session() as sess:
        dcgan = DCGAN(sess,
                      dataset=FLAGS.dataset,
                      batch_size=FLAGS.batch_size,
                      output_size=FLAGS.output_size,
                      c_dim=FLAGS.c_dim,
                      z_dim=FLAGS.z_dim)

        if FLAGS.is_train:
            if FLAGS.preload_data:
                # Preload mode: the data is produced by get_data_arr.
                data = get_data_arr(FLAGS)
            else:
                # Otherwise only the list of image paths is collected: every
                # *.jpg directly inside ./data/<dataset>/ (relative to the
                # current working directory; not recursive).
                data = glob(os.path.join('./data', FLAGS.dataset, '*.jpg'))
            train.train_wasserstein(sess, dcgan, data, FLAGS)
        else:
            # NOTE(review): the model is only restored here -- nothing is done
            # with it afterwards, and any return value of load() is ignored.
            # Confirm whether sampling/evaluation was intended in this branch.
            dcgan.load(FLAGS.checkpoint_dir)
评论列表
文章目录