def train(epoch=25, learning_rate=0.0002, beta1=0.5, train_size=np.inf,
          batch_size=64, input_height=108, input_width=None,
          output_height=64, output_width=None, dataset='celebA',
          input_fname_pattern='*.jpg', checkpoint_dir='checkpoints',
          sample_dir='samples', output_dir='output', crop=True,
          model_dir='temp', model_filename='dcgan'):
    """Build a DCGAN in a fresh TensorFlow session, train it, then save it.

    Every argument except ``model_dir`` and ``model_filename`` is forwarded to
    ``DCGAN(...)`` and/or ``DCGAN.train(...)``. Their exact meaning is defined
    by that class, which is not visible here.

    Args:
        epoch: Forwarded to ``dcgan.train``.
        learning_rate: Forwarded to ``dcgan.train``.
        beta1: Forwarded to ``dcgan.train``. Presumably the Adam beta1
            momentum term; confirm against ``DCGAN.train``.
        train_size: Forwarded to ``dcgan.train``. Presumably caps the number
            of training images (``np.inf`` = no cap); confirm.
        batch_size: Used both as the model batch size and as ``sample_num``.
        input_height: Input image height.
        input_width: Input image width. Defaults to ``input_height`` when None.
        output_height: Output image height.
        output_width: Output image width. Defaults to ``output_height`` when
            None.
        dataset: Passed to ``DCGAN`` as ``dataset_name`` and to
            ``dcgan.train`` as ``dataset``.
        input_fname_pattern: Glob-style file pattern forwarded to the model.
        checkpoint_dir: Forwarded to the model. It is NOT created here.
        sample_dir: Created if missing, then forwarded to the model.
        output_dir: Created if missing, then forwarded to the model.
        crop: Forwarded to the model.
        model_dir: First argument to ``dcgan.save``.
        model_filename: Third argument to ``dcgan.save``.
    """
    # A missing width falls back to the corresponding height (square images).
    if input_width is None:
        input_width = input_height
    if output_width is None:
        output_width = output_height

    # exist_ok=True avoids the race between an existence check and creation.
    # NOTE(review): checkpoint_dir is not created here; presumably DCGAN or
    # dcgan.save creates it as needed. Confirm.
    for directory in (sample_dir, output_dir):
        os.makedirs(directory, exist_ok=True)

    run_config = tf.ConfigProto()
    # Let TensorFlow allocate GPU memory as needed instead of all up front.
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        dcgan = DCGAN(
            sess,
            input_width=input_width,
            input_height=input_height,
            output_width=output_width,
            output_height=output_height,
            batch_size=batch_size,
            sample_num=batch_size,
            dataset_name=dataset,
            input_fname_pattern=input_fname_pattern,
            crop=crop,
            checkpoint_dir=checkpoint_dir,
            sample_dir=sample_dir,
            output_dir=output_dir)

        show_all_variables()

        # The original passed ``train=train``. That bound the enclosing module
        # function object (always truthy), not a flag. Pass an explicit True
        # for the training-mode flag, which preserves the same truthiness.
        dcgan.train(
            epoch=epoch,
            learning_rate=learning_rate,
            beta1=beta1,
            train_size=train_size,
            batch_size=batch_size,
            input_height=input_height,
            input_width=input_width,
            output_height=output_height,
            output_width=output_width,
            dataset=dataset,
            input_fname_pattern=input_fname_pattern,
            checkpoint_dir=checkpoint_dir,
            sample_dir=sample_dir,
            output_dir=output_dir,
            train=True,
            crop=crop)

        dcgan.save(model_dir, dcgan.global_training_steps, model_filename)
# 评论列表 (Comment list) — blog-page residue, not code
# 文章目录 (Table of contents) — blog-page residue, not code