main.py source code

python

Project: photo-editing-tensorflow, author: JamesChuanggg
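import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.GPUOptions, tf.ConfigProto)

# conf, calc_gpu_fraction, data_loader, Curve, CNN and Policy are defined
# elsewhere in the photo-editing-tensorflow project.


def main(_):
  # Preprocess: observation_dims is stored as a string, so evaluate it
  # into a Python value before building the network.
  conf.observation_dims = eval(conf.observation_dims)

  # Cap the fraction of GPU memory this process may allocate.
  gpu_options = tf.GPUOptions(
      per_process_gpu_memory_fraction=calc_gpu_fraction(conf.gpu_fraction))

  # Load the source and target image sets.
  dataset = data_loader(conf.source_path, conf.target_path)

  with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    env = Curve()

    pred_network = CNN(sess=sess,
                       observation_dims=conf.observation_dims,
                       name='pred_network',
                       trainable=True)

    policy = Policy(sess=sess,
                    pred_network=pred_network,
                    env=env,
                    dataset=dataset,
                    conf=conf)

    # Train the policy, or run it on a single test image.
    if conf.is_train:
      policy.train()
    else:
      policy.test(conf.test_image_path)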
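`main` relies on two configuration helpers that this snippet does not show: `conf.observation_dims` is a string turned into a Python value with `eval`, and `calc_gpu_fraction(conf.gpu_fraction)` produces the value passed to `per_process_gpu_memory_fraction`. The sketch below is a hypothetical stand-in, not the project's code. It assumes `gpu_fraction` is a string such as `"1/3"` and `observation_dims` is a list literal such as `"[64, 64]"`, and it uses `ast.literal_eval`, which accepts only Python literals, as a safer alternative to `eval`.

```python
import ast
from fractions import Fraction


def parse_observation_dims(text):
    """Hypothetical helper: parse a literal like "[64, 64]" into a list of ints.

    ast.literal_eval accepts only Python literals, so unlike eval it cannot
    execute arbitrary expressions passed in through a config flag.
    """
    dims = ast.literal_eval(text)
    if not isinstance(dims, (list, tuple)) or not all(isinstance(d, int) for d in dims):
        raise ValueError("observation_dims must be a list of integers: %r" % (text,))
    return list(dims)


def calc_gpu_fraction(fraction_string):
    """Hypothetical stand-in: turn a string such as "1/3" into a float in (0, 1].

    Assumes the flag is written as "numerator/denominator"; the project's
    real helper is not shown in this snippet.
    """
    fraction = float(Fraction(fraction_string))
    if not 0.0 < fraction <= 1.0:
        raise ValueError("gpu_fraction must be in (0, 1]: %r" % (fraction_string,))
    return fraction


if __name__ == "__main__":
    print(parse_observation_dims("[64, 64]"))  # [64, 64]
    print(calc_gpu_fraction("1/3"))
```

Keeping the fraction as a string lets a value like `1/3` be written exactly; `Fraction` also accepts decimal strings such as `"0.5"`.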