def run(checkpoint_dir='checkpoints', batch_size=64, input_height=108,
        input_width=None, output_height=64, output_width=None,
        dataset='celebA', input_fname_pattern='*.jpg', output_dir='output',
        sample_dir='samples', crop=True):
    """Build a DCGAN in a TensorFlow session, initialize it, and visualize it.

    A ``tf.Session`` is opened with GPU ``allow_growth`` enabled, a ``DCGAN``
    is constructed from the arguments, all variables are listed and
    initialized, and ``visualize`` is called on the resulting model.

    Args:
        checkpoint_dir: Forwarded to ``DCGAN`` as ``checkpoint_dir``.
        batch_size: Forwarded to ``DCGAN`` as both ``batch_size`` and
            ``sample_num``, and to ``visualize`` as ``batch_size``.
        input_height: Forwarded to ``DCGAN`` and ``visualize``.
        input_width: Forwarded to ``DCGAN`` and ``visualize``. When ``None``,
            ``input_height`` is used instead.
        output_height: Forwarded to ``DCGAN``.
        output_width: Forwarded to ``DCGAN``. When ``None``,
            ``output_height`` is used instead.
        dataset: Forwarded to ``DCGAN`` as ``dataset_name``.
        input_fname_pattern: Forwarded to ``DCGAN``.
        output_dir: Forwarded to ``DCGAN`` and ``visualize``.
        sample_dir: Forwarded to ``DCGAN``.
        crop: Forwarded to ``DCGAN``.

    Returns:
        None.
    """
    # A missing width falls back to the corresponding height.
    if input_width is None:
        input_width = input_height
    if output_width is None:
        output_width = output_height

    # NOTE: checkpoint_dir and output_dir are NOT created here; directory
    # creation was disabled in this function, so callers that need the
    # directories to exist must create them beforehand.

    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        dcgan = DCGAN(
            sess,
            input_width=input_width,
            input_height=input_height,
            output_width=output_width,
            output_height=output_height,
            batch_size=batch_size,
            sample_num=batch_size,
            dataset_name=dataset,
            input_fname_pattern=input_fname_pattern,
            crop=crop,
            checkpoint_dir=checkpoint_dir,
            sample_dir=sample_dir,
            output_dir=output_dir)

        show_all_variables()

        # Fall back to the legacy initializer only when the newer API is
        # missing from the installed TensorFlow. The try body covers just
        # the lookup, so a real failure while running the initializer (or a
        # KeyboardInterrupt) propagates instead of being masked by a retry.
        try:
            initializer = tf.global_variables_initializer()
        except AttributeError:
            initializer = tf.initialize_all_variables()
        initializer.run()

        # Visualization needs the live session, so it stays inside the
        # ``with`` block.
        visualize(sess, dcgan, batch_size=batch_size,
                  input_height=input_height, input_width=input_width,
                  output_dir=output_dir)
评论列表
文章目录