def main(_):
    """Build a DCGAN from FLAGS, then train it or restore it from a checkpoint.

    The positional argument is ignored; all configuration is read from the
    module-level ``FLAGS``.

    Training hands ``sess``, the model, the training data and ``FLAGS`` to
    ``train.train_wasserstein``. The training data is one of:

    * the array returned by ``get_data_arr(FLAGS)``, when
      ``FLAGS.preload_data`` is set;
    * otherwise, the list of ``*.jpg`` paths under ``./data/<FLAGS.dataset>/``.

    When not training, it only calls ``dcgan.load(FLAGS.checkpoint_dir)``.
    """
    pp.pprint(flags.FLAGS.__flags)

    # Create the output directories up front so later writes cannot fail on a
    # missing folder.
    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    with tf.Session() as sess:
        dcgan = DCGAN(sess,
                      dataset=FLAGS.dataset,
                      batch_size=FLAGS.batch_size,
                      output_size=FLAGS.output_size,
                      c_dim=FLAGS.c_dim,
                      z_dim=FLAGS.z_dim)

        if FLAGS.is_train:
            # Test the flag's truthiness directly (PEP 8: don't compare booleans
            # to True with ==).
            if FLAGS.preload_data:
                data = get_data_arr(FLAGS)
            else:
                data = glob(os.path.join('./data', FLAGS.dataset, '*.jpg'))
            train.train_wasserstein(sess, dcgan, data, FLAGS)
        else:
            dcgan.load(FLAGS.checkpoint_dir)
# Example source code using the Python class DCGAN
def main(_):
    """Build a DCGAN from FLAGS, then train it or run one of its test modes.

    The positional argument is unused.

    For ``FLAGS.dataset == 'mnist'`` the model additionally receives
    ``y_dim=10``; otherwise the constructor arguments are identical.

    Without ``FLAGS.is_train``, the first matching test mode is run:

    * ``FLAGS.is_single`` -> ``single_test``;
    * ``FLAGS.is_small`` -> ``batch_test2``;
    * otherwise -> ``batch_test``.
    """
    pp.pprint(flags.FLAGS.__flags)
    for required_dir in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(required_dir):
            os.makedirs(required_dir)

    with tf.Session() as sess:
        model_kwargs = dict(image_size=FLAGS.image_size,
                            batch_size=FLAGS.batch_size,
                            dataset_name=FLAGS.dataset,
                            is_crop=FLAGS.is_crop,
                            checkpoint_dir=FLAGS.checkpoint_dir)
        if FLAGS.dataset == 'mnist':
            # Presumably the 10 digit-class labels used for conditioning.
            model_kwargs['y_dim'] = 10
        dcgan = DCGAN(sess, **model_kwargs)

        if FLAGS.is_train:
            dcgan.train(FLAGS)
        elif FLAGS.is_single:
            dcgan.single_test(FLAGS.checkpoint_dir, FLAGS.file_name)
        elif FLAGS.is_small:
            dcgan.batch_test2(FLAGS.checkpoint_dir)
        else:
            dcgan.batch_test(FLAGS.checkpoint_dir, FLAGS.file_name)

        # Disabled code kept for reference (previously a no-op string literal):
        # dcgan.load(FLAGS.checkpoint_dir)
        # dcgan.single_test(FLAGS.checkpoint_dir)
        # dcgan.batch_test(FLAGS.checkpoint_dir)
        # if FLAGS.visualize:
        #     to_json("./web/js/layers.js", [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0],
        #             [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1],
        #             [dcgan.h2_w, dcgan.h2_b, dcgan.g_bn2],
        #             [dcgan.h3_w, dcgan.h3_b, dcgan.g_bn3],
        #             [dcgan.h4_w, dcgan.h4_b, None])
        #     # Below is codes for visualization
        #     OPTION = 2
        #     visualize(sess, dcgan, FLAGS, OPTION)
def main(_):
    """Train a DCGAN configured for 112x112x3 inputs and a 100-dim z.

    The positional argument is unused. ``FLAGS.log_dev_placement`` controls
    TensorFlow's device-placement logging, and ``FLAGS.batch_size`` sets the
    batch size.
    """
    session_config = tf.ConfigProto(
        log_device_placement=FLAGS.log_dev_placement)
    with tf.Session(config=session_config) as sess:
        # An earlier configuration used in_dim=[28, 28, 1].
        model = DCGAN(sess,
                      batch_size=FLAGS.batch_size,
                      in_dim=[112, 112, 3],
                      z_dim=100)
        model.train(FLAGS)
def train(epoch=25, learning_rate=0.0002, beta1=0.5, train_size=np.inf,
          batch_size=64, input_height=108, input_width=None,
          output_height=64, output_width=None, dataset='celebA',
          input_fname_pattern='*.jpg', checkpoint_dir='checkpoints',
          sample_dir='samples', output_dir='output', crop=True,
          model_dir='temp', model_filename='dcgan'):
    """Build a DCGAN, train it, and save the result.

    Args:
        epoch, learning_rate, beta1, train_size: forwarded to ``dcgan.train``.
        batch_size: batch size; also used as the model's ``sample_num``.
        input_height, input_width: input image size. ``input_width`` defaults
            to ``input_height``.
        output_height, output_width: output image size. ``output_width``
            defaults to ``output_height``.
        dataset: passed to the model as ``dataset_name``.
        input_fname_pattern: passed to the model as ``input_fname_pattern``.
        checkpoint_dir: passed to the model as ``checkpoint_dir``.
        sample_dir, output_dir: passed to the model. Both directories are
            created here if missing; ``checkpoint_dir`` is not.
        crop: passed to the model as ``crop``.
        model_dir, model_filename: forwarded to ``dcgan.save`` together with
            ``dcgan.global_training_steps`` after training.
    """
    # Unspecified widths fall back to the corresponding heights.
    input_width = input_height if input_width is None else input_width
    output_width = output_height if output_width is None else output_width

    for folder in (sample_dir, output_dir):
        if not os.path.exists(folder):
            os.makedirs(folder)

    session_config = tf.ConfigProto()
    # Let TensorFlow grow GPU memory on demand rather than reserving it all.
    session_config.gpu_options.allow_growth = True

    with tf.Session(config=session_config) as sess:
        dcgan = DCGAN(sess,
                      input_width=input_width,
                      input_height=input_height,
                      output_width=output_width,
                      output_height=output_height,
                      batch_size=batch_size,
                      sample_num=batch_size,
                      dataset_name=dataset,
                      input_fname_pattern=input_fname_pattern,
                      crop=crop,
                      checkpoint_dir=checkpoint_dir,
                      sample_dir=sample_dir,
                      output_dir=output_dir)
        show_all_variables()

        # NOTE(review): `train` is not a parameter here, so `train=train` passes
        # the module-level object named `train` (most likely this function
        # itself) rather than a boolean flag. Confirm what dcgan.train expects.
        dcgan.train(epoch=epoch,
                    learning_rate=learning_rate,
                    beta1=beta1,
                    train_size=train_size,
                    batch_size=batch_size,
                    input_height=input_height,
                    input_width=input_width,
                    output_height=output_height,
                    output_width=output_width,
                    dataset=dataset,
                    input_fname_pattern=input_fname_pattern,
                    checkpoint_dir=checkpoint_dir,
                    sample_dir=sample_dir,
                    output_dir=output_dir,
                    train=train,
                    crop=crop)
        dcgan.save(model_dir, dcgan.global_training_steps, model_filename)
def main(_):
    """Train or restore a DCGAN from FLAGS, then optionally visualize it.

    The positional argument is unused. With ``FLAGS.is_train`` the model is
    trained; otherwise it is loaded from ``FLAGS.checkpoint_dir``. In both
    cases, if ``FLAGS.visualize`` is set, ``visualize`` is called with
    option 2.
    """
    pp.pprint(flags.FLAGS.__flags)
    for path in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(path):
            os.makedirs(path)

    with tf.Session() as sess:
        dcgan = DCGAN(sess,
                      image_size=FLAGS.image_size,
                      output_size=FLAGS.output_size,
                      batch_size=FLAGS.batch_size,
                      sample_size=FLAGS.sample_size)

        if not FLAGS.is_train:
            dcgan.load(FLAGS.checkpoint_dir)
        else:
            dcgan.train(FLAGS)

        if FLAGS.visualize:
            # The to_json("./web/js/layers.js", ...) layer export is disabled.
            visualization_option = 2
            visualize(sess, dcgan, FLAGS, visualization_option)
def main(_):
    """Build the multi-GPU DCGAN variant, train or restore it, then visualize.

    The positional argument is unused; configuration comes from ``FLAGS``.
    The session allows soft placement and disables device-placement logging.

    Raises:
        AssertionError: if ``FLAGS.dataset`` is ``'mnist'``, which this
            variant does not support.
    """
    pp.pprint(flags.FLAGS.__flags)
    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True, log_device_placement=False)) as sess:
        if FLAGS.dataset == 'mnist':
            # Use an explicit raise: `assert False` is stripped under `python -O`,
            # which would silently let the unsupported dataset through.
            raise AssertionError(
                "dataset 'mnist' is not supported by this model variant")

        dcgan = DCGAN(sess,
                      image_size=FLAGS.image_size,
                      batch_size=FLAGS.batch_size,
                      sample_size=64,
                      z_dim=8192,
                      d_label_smooth=.25,
                      generator_target_prob=.75 / 2.,
                      out_stddev=.075,
                      out_init_b=-.45,
                      # FLAGS.image_width is used for both spatial dimensions.
                      image_shape=[FLAGS.image_width, FLAGS.image_width, 3],
                      dataset_name=FLAGS.dataset,
                      is_crop=FLAGS.is_crop,
                      checkpoint_dir=FLAGS.checkpoint_dir,
                      sample_dir=FLAGS.sample_dir,
                      generator=Generator(),
                      train_func=train,
                      discriminator_func=discriminator,
                      build_model_func=build_model,
                      config=FLAGS,
                      devices=["gpu:0", "gpu:1", "gpu:2", "gpu:3"])  # "gpu:4" disabled

        if FLAGS.is_train:
            # print(...) with one argument behaves the same on Python 2 and 3.
            print("TRAINING")
            dcgan.train(FLAGS)
            print("DONE TRAINING")
        else:
            dcgan.load(FLAGS.checkpoint_dir)

        OPTION = 2
        visualize(sess, dcgan, FLAGS, OPTION)
def run(checkpoint_dir='checkpoints', batch_size=64, input_height=108,
        input_width=None, output_height=64, output_width=None,
        dataset='celebA', input_fname_pattern='*.jpg', output_dir='output',
        sample_dir='samples', crop=True):
    """Build a DCGAN, initialize its variables and run ``visualize`` on it.

    Args:
        checkpoint_dir, dataset, input_fname_pattern, sample_dir, crop:
            forwarded to the DCGAN constructor.
        batch_size: batch size; also used as the model's ``sample_num``.
        input_height, input_width: input image size. ``input_width`` defaults
            to ``input_height``.
        output_height, output_width: output image size. ``output_width``
            defaults to ``output_height``.
        output_dir: forwarded to the model and to ``visualize``. It is not
            created here.
    """
    if input_width is None:
        input_width = input_height
    if output_width is None:
        output_width = output_height

    run_config = tf.ConfigProto()
    # Let TensorFlow grow GPU memory on demand rather than reserving it all.
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        dcgan = DCGAN(sess,
                      input_width=input_width,
                      input_height=input_height,
                      output_width=output_width,
                      output_height=output_height,
                      batch_size=batch_size,
                      sample_num=batch_size,
                      dataset_name=dataset,
                      input_fname_pattern=input_fname_pattern,
                      crop=crop,
                      checkpoint_dir=checkpoint_dir,
                      sample_dir=sample_dir,
                      output_dir=output_dir)
        show_all_variables()

        # Older TensorFlow releases only have the deprecated
        # initialize_all_variables. Fall back to it ONLY when the newer name is
        # missing. The previous bare `except:` also swallowed genuine errors
        # raised while running the initializer (and KeyboardInterrupt).
        try:
            init_op = tf.global_variables_initializer()
        except AttributeError:
            init_op = tf.initialize_all_variables()
        init_op.run()

        # NOTE(review): no checkpoint is restored here, so unless DCGAN or
        # visualize loads weights themselves, this visualizes freshly
        # initialized variables. Confirm that this is intended.
        visualize(sess, dcgan, batch_size=batch_size, input_height=input_height,
                  input_width=input_width, output_dir=output_dir)
def main(_):
    """Configure a dataset from FLAGS and run the Discriminator in the chosen mode.

    The positional argument is unused.

    * ``'mnist'``: a label-conditioned DCGAN (``y_dim=10``) is constructed, and
      no mode is dispatched.
    * ``'cityscapes'`` and ``'inria'``: output size, cropping and
      ``dataset_dir`` are overridden on ``FLAGS`` before a ``Discriminator`` is
      built. It is then run according to ``FLAGS.mode`` ('test', 'train' or
      'complete').

    Raises:
        ValueError: if ``FLAGS.mode`` is 'train' or 'test' but the selected
            dataset has no synthetic-image directory configured. Only
            'cityscapes' defines one. Previously this surfaced as an
            UnboundLocalError on ``syn_dir``.
    """
    pp.pprint(flags.FLAGS.__flags)
    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    # Do not take all GPU memory: limit this process to 30% of it.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.30)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        if FLAGS.dataset == 'mnist':
            # With y labels.
            # NOTE(review): this path only constructs `dcgan`; nothing is
            # trained or tested afterwards.
            dcgan = DCGAN(sess, image_size=FLAGS.image_size,
                          batch_size=FLAGS.batch_size, y_dim=10, output_size=28,
                          c_dim=1, dataset_name=FLAGS.dataset,
                          checkpoint_dir=FLAGS.checkpoint_dir)
        else:
            # Without y labels. Only cityscapes configures synthetic images.
            syn_dir = None
            if FLAGS.dataset == 'cityscapes':
                print('Select CITYSCAPES')
                syn_dir = CITYSCAPES_syn_dir_2
                FLAGS.output_size_h, FLAGS.output_size_w, FLAGS.is_crop = 192, 512, False
                FLAGS.dataset_dir = CITYSCAPES_dir
            elif FLAGS.dataset == 'inria':
                print('Select INRIAPerson')
                FLAGS.output_size_h, FLAGS.output_size_w, FLAGS.is_crop = 160, 96, False
                FLAGS.dataset_dir = INRIA_dir

            # Fail early and clearly instead of with an UnboundLocalError later.
            if FLAGS.mode in ('train', 'test') and syn_dir is None:
                raise ValueError(
                    "mode %r needs a synthetic-image directory, which is not "
                    "configured for dataset %r" % (FLAGS.mode, FLAGS.dataset))

            discriminator = Discriminator(
                sess, batch_size=FLAGS.batch_size,
                output_size_h=FLAGS.output_size_h,
                output_size_w=FLAGS.output_size_w, c_dim=FLAGS.c_dim,
                dataset_name=FLAGS.dataset,
                checkpoint_dir=FLAGS.checkpoint_dir,
                dataset_dir=FLAGS.dataset_dir)

            if FLAGS.mode == 'test':
                print('Testing!')
                discriminator.test(FLAGS, syn_dir)
            elif FLAGS.mode == 'train':
                print('Train!')
                discriminator.train(FLAGS, syn_dir)
            elif FLAGS.mode == 'complete':
                print('Complete!')
def main(_):
    """Configure a dataset from FLAGS, build a DCGAN and run it in the chosen mode.

    The positional argument is unused.

    * ``'mnist'``: a label-conditioned DCGAN (``y_dim=10``, 28px, 1 channel) is
      built.
    * Other datasets: output size, cropping and ``dataset_dir`` may first be
      overridden on ``FLAGS``.

    ``FLAGS.mode`` then selects ``dcgan.test``, ``dcgan.train`` or
    ``dcgan.complete``.

    Raises:
        ValueError: if ``FLAGS.mode`` is 'complete' but the selected dataset
            has no mask directory configured. Only 'cityscapes' defines one.
            Previously this surfaced as an UnboundLocalError on ``mask_dir``.
    """
    pp.pprint(flags.FLAGS.__flags)
    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    # Do not take all GPU memory: limit this process to 80% of it.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.80)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        # Only the cityscapes configuration provides masks for 'complete' mode.
        mask_dir = None
        if FLAGS.dataset == 'mnist':
            # With y labels.
            dcgan = DCGAN(sess, image_size=FLAGS.image_size,
                          batch_size=FLAGS.batch_size, y_dim=10, output_size=28,
                          c_dim=1, dataset_name=FLAGS.dataset,
                          is_crop=FLAGS.is_crop,
                          checkpoint_dir=FLAGS.checkpoint_dir)
        else:
            # Without y labels.
            if FLAGS.dataset == 'cityscapes':
                print('Select CITYSCAPES')
                mask_dir = CITYSCAPES_mask_dir
                FLAGS.output_size_h, FLAGS.output_size_w, FLAGS.is_crop = 192, 512, False
                FLAGS.dataset_dir = CITYSCAPES_dir
            elif FLAGS.dataset == 'inria':
                print('Select INRIAPerson')
                FLAGS.output_size_h, FLAGS.output_size_w, FLAGS.is_crop = 160, 96, False
                FLAGS.dataset_dir = INRIA_dir
            elif FLAGS.dataset == 'indoor':
                print('Select indoor')
                FLAGS.output_size_h, FLAGS.output_size_w, FLAGS.is_crop = 256, 256, False
                FLAGS.dataset_dir = indoor_dir
            elif FLAGS.dataset == 'indoor_bedroom':
                print('Select indoor bedroom')
                FLAGS.output_size_h, FLAGS.output_size_w, FLAGS.is_crop = 256, 256, False
                FLAGS.dataset_dir = indoor_bedroom_dir
            elif FLAGS.dataset == 'indoor_dining':
                print('Select indoor dining')
                FLAGS.output_size_h, FLAGS.output_size_w, FLAGS.is_crop = 256, 256, False
                # NOTE(review): this reuses the bedroom directory, which looks
                # like a copy-paste slip (a dining-room directory was probably
                # intended). Confirm before changing.
                FLAGS.dataset_dir = indoor_bedroom_dir
            dcgan = DCGAN(sess, batch_size=FLAGS.batch_size,
                          output_size_h=FLAGS.output_size_h,
                          output_size_w=FLAGS.output_size_w, c_dim=FLAGS.c_dim,
                          dataset_name=FLAGS.dataset, is_crop=FLAGS.is_crop,
                          checkpoint_dir=FLAGS.checkpoint_dir,
                          dataset_dir=FLAGS.dataset_dir)

        if FLAGS.mode == 'test':
            print('Testing!')
            dcgan.test(FLAGS)
        elif FLAGS.mode == 'train':
            print('Train!')
            dcgan.train(FLAGS)
        elif FLAGS.mode == 'complete':
            print('Complete!')
            if mask_dir is None:
                raise ValueError(
                    "mode 'complete' needs a mask directory, which is not "
                    "configured for dataset %r" % (FLAGS.dataset,))
            dcgan.complete(FLAGS, mask_dir)