def main(_):
    """Build the prediction network and policy, then train or test.

    The single positional argument is ignored. Presumably it is the argv
    list passed by ``tf.app.run``; confirm against the caller.

    Reads these fields of the module-level ``conf`` object, which is
    defined elsewhere: ``observation_dims``, ``gpu_fraction``,
    ``source_path``, ``target_path``, ``is_train`` and ``test_image_path``.

    Side effects:
        Converts ``conf.observation_dims`` in place from its string form
        to a Python literal before the network is built. The string is
        assumed to look like ``"[80, 80]"``; TODO confirm the flag format.
        Then runs either ``policy.train()`` or
        ``policy.test(conf.test_image_path)`` inside a TensorFlow session.
    """
    import ast

    # The dims arrive as a string. Parse them with ast.literal_eval rather
    # than eval(), so a command-line/config value cannot execute arbitrary
    # code. Skip the parse if the value was already converted, e.g. when
    # main() runs twice.
    if isinstance(conf.observation_dims, str):
        conf.observation_dims = ast.literal_eval(conf.observation_dims)

    # Limit this process's GPU memory to the fraction derived from config.
    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=calc_gpu_fraction(conf.gpu_fraction))
    dataset = data_loader(conf.source_path, conf.target_path)

    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        env = Curve()
        pred_network = CNN(sess=sess,
                           observation_dims=conf.observation_dims,
                           name='pred_network',
                           trainable=True)
        policy = Policy(sess=sess,
                        pred_network=pred_network,
                        env=env,
                        dataset=dataset,
                        conf=conf)
        # conf.is_train selects the mode; test mode needs an image path.
        if conf.is_train:
            policy.train()
        else:
            policy.test(conf.test_image_path)
# (Scraped web-page residue, kept for reference: "评论列表" = "Comment list",
# "文章目录" = "Table of contents".)