def main(_):
    """Run the agent in training or play mode, as selected by FLAGS.

    Opens a TensorFlow session whose per-process GPU memory fraction is
    computed by ``calc_gpu_fraction(FLAGS.gpu_fraction)``, builds the
    environment and the agent, then calls ``agent.train()`` when
    ``FLAGS.is_train`` is truthy and ``agent.play()`` otherwise.

    Args:
        _: Ignored (presumably the argv list handed over by the app
            runner -- confirm against the caller).

    Raises:
        Exception: If ``FLAGS.use_gpu`` is set but TensorFlow reports that
            no GPU is available.
    """
    session_config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(
            per_process_gpu_memory_fraction=calc_gpu_fraction(
                FLAGS.gpu_fraction)))

    with tf.Session(config=session_config) as session:
        # Fall back to the raw flags when get_config() returns a falsy value.
        settings = get_config(FLAGS) or FLAGS

        # Only 'simple' selects the simplified environment; every other
        # env_type value gets the regular one.
        env_class = (SimpleGymEnvironment
                     if settings.env_type == 'simple' else GymEnvironment)
        environment = env_class(settings)

        # The availability probe runs unconditionally, before the flag is
        # consulted, so its side effects (if any) always happen.
        gpu_found = tf.test.is_gpu_available()
        if FLAGS.use_gpu and not gpu_found:
            raise Exception("use_gpu flag is true when no GPUs are available")

        # Without a GPU the channels-last layout is forced
        # (NOTE(review): presumably a CPU kernel limitation -- confirm).
        if not FLAGS.use_gpu:
            settings.cnn_format = 'NHWC'

        player = Agent(settings, environment, session)
        run = player.train if FLAGS.is_train else player.play
        run()
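# [web-page residue, not part of the program] Comment list
# [web-page residue, not part of the program] Table of contents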