main.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:easygen 作者: markriedl 项目源码 文件源码
def train(
    epoch=25,
    learning_rate=0.0002,
    beta1=0.5,
    train_size=np.inf,
    batch_size=64,
    input_height=108,
    input_width=None,
    output_height=64,
    output_width=None,
    dataset='celebA',
    input_fname_pattern='*.jpg',
    checkpoint_dir='checkpoints',
    sample_dir='samples',
    output_dir='output',
    crop=True,
    model_dir='temp',
    model_filename='dcgan'):
  """Construct a DCGAN in a fresh TensorFlow session, train it, and save it.

  Args:
    epoch, learning_rate, beta1, train_size: forwarded unchanged to
      `DCGAN.train`.
    batch_size: given to the `DCGAN` constructor as both `batch_size` and
      `sample_num`, and also forwarded to `DCGAN.train`.
    input_height, output_height: forwarded to the constructor and to
      `DCGAN.train`.
    input_width, output_width: when None, the matching height is used
      instead; the resolved value is forwarded like the heights.
    dataset: given to the constructor as `dataset_name` and to
      `DCGAN.train` as `dataset`.
    input_fname_pattern, crop, checkpoint_dir: forwarded to the constructor
      and to `DCGAN.train`. `checkpoint_dir` is not created by this function.
    sample_dir, output_dir: created on disk if they do not exist, then
      forwarded to the constructor and to `DCGAN.train`.
    model_dir, model_filename: handed to `DCGAN.save` together with the
      model's `global_training_steps` once training returns.
  """
  # Width falls back to the corresponding height when it is not supplied.
  input_width = input_height if input_width is None else input_width
  output_width = output_height if output_width is None else output_width

  # Only the sample and output directories are created here.
  for directory in (sample_dir, output_dir):
    if not os.path.exists(directory):
      os.makedirs(directory)

  session_config = tf.ConfigProto()
  session_config.gpu_options.allow_growth = True

  with tf.Session(config=session_config) as session:
    model_kwargs = dict(
        input_width=input_width,
        input_height=input_height,
        output_width=output_width,
        output_height=output_height,
        batch_size=batch_size,
        sample_num=batch_size,
        dataset_name=dataset,
        input_fname_pattern=input_fname_pattern,
        crop=crop,
        checkpoint_dir=checkpoint_dir,
        sample_dir=sample_dir,
        output_dir=output_dir,
    )
    model = DCGAN(session, **model_kwargs)

    show_all_variables()

    # NOTE(review): this function has no parameter or local named `train`,
    # so the name below resolves to the module-level `train` (presumably
    # this very function object) rather than a boolean flag. Confirm what
    # `DCGAN.train` expects for its `train` argument.
    training_kwargs = dict(
        epoch=epoch,
        learning_rate=learning_rate,
        beta1=beta1,
        train_size=train_size,
        batch_size=batch_size,
        input_height=input_height,
        input_width=input_width,
        output_height=output_height,
        output_width=output_width,
        dataset=dataset,
        input_fname_pattern=input_fname_pattern,
        checkpoint_dir=checkpoint_dir,
        sample_dir=sample_dir,
        output_dir=output_dir,
        train=train,
        crop=crop,
    )
    model.train(**training_kwargs)

    # Persist the trained model under model_dir using the step counter the
    # model exposes.
    model.save(
        model_dir,
        model.global_training_steps,
        model_filename,
    )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号