def main(_):
    """Set up output directories, build the DCGAN, then train or visualize it.

    The single positional argument is ignored; presumably it is the ``argv``
    handed over by the flag-parsing app runner (TODO confirm against caller).
    All behaviour is driven by the module-level ``FLAGS``:

    * ``FLAGS.is_train`` truthy -> ``dcgan.train(FLAGS)``.
    * otherwise -> ``dcgan.load(FLAGS.checkpoint_dir)`` followed by
      ``visualize(sess, dcgan, FLAGS, 2)``.

    Raises:
        ValueError: if ``FLAGS.dataset`` is ``'mnist'``, which this script
            does not support.
        OSError: if an output directory cannot be created, or its path is
            occupied by something that is not a directory.
    """
    pp.pprint(flags.FLAGS.__flags)

    # Reject unsupported datasets before creating directories or a multi-GPU
    # session. A bare ``assert False`` would be stripped under ``python -O``
    # and let the run continue silently.
    if FLAGS.dataset == 'mnist':
        raise ValueError("The 'mnist' dataset is not supported by this script.")

    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        try:
            os.makedirs(directory)
        except OSError:
            # Python 2's os.makedirs has no ``exist_ok``: tolerate an existing
            # directory, but re-raise real failures (permissions, or a regular
            # file sitting at that path).
            if not os.path.isdir(directory):
                raise

    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)
    with tf.Session(config=session_config) as sess:
        # Model hyperparameters below are passed straight to DCGAN; their exact
        # semantics are defined by that class (not visible here).
        dcgan = DCGAN(
            sess,
            image_size=FLAGS.image_size,
            batch_size=FLAGS.batch_size,
            sample_size=64,
            z_dim=8192,
            d_label_smooth=.25,
            generator_target_prob=.75 / 2.,
            out_stddev=.075,
            out_init_b=-.45,
            image_shape=[FLAGS.image_width, FLAGS.image_width, 3],
            dataset_name=FLAGS.dataset,
            is_crop=FLAGS.is_crop,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir,
            generator=Generator(),
            train_func=train,
            discriminator_func=discriminator,
            build_model_func=build_model,
            config=FLAGS,
            # Hard-coded to four GPUs; adjust for the target machine.
            devices=["gpu:0", "gpu:1", "gpu:2", "gpu:3"],
        )

        if FLAGS.is_train:
            print("TRAINING")
            dcgan.train(FLAGS)
            print("DONE TRAINING")
        else:
            dcgan.load(FLAGS.checkpoint_dir)
            # Mode selector understood by ``visualize`` (defined elsewhere);
            # the meaning of each value is not visible from here.
            visualize_option = 2
            visualize(sess, dcgan, FLAGS, visualize_option)
评论列表
文章目录