def main(argv=None):
    """Parse command-line options and train on one game of the Atari40M benchmark.

    Args:
        argv: Optional list of command-line arguments to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, so existing
            ``main()`` calls behave exactly as before.

    The log file name passed to ``atari_learn`` as ``log_file`` is
    ``<game>_s<seed zero-padded to 3 digits>.pkl`` (e.g. ``Pong_s000.pkl``).
    An unknown ``--game`` makes argparse print a usage error and exit with
    status 2.
    """
    # Games that we'll be testing, mapped to their index in benchmark.tasks.
    # NOTE(review): these indices assume the Atari40M task ordering of the
    # installed gym version — confirm if gym is upgraded.
    game_to_ID = {'BeamRider': 0,
                  'Breakout': 1,
                  'Enduro': 2,
                  'Pong': 3,
                  'Qbert': 4}

    parser = argparse.ArgumentParser()
    # Limiting --game to the known names gives a clean usage error instead of
    # the KeyError traceback the game_to_ID lookup below would otherwise raise.
    parser.add_argument('--game', type=str, default='Pong',
                        choices=sorted(game_to_ID))
    parser.add_argument('--seed', type=int, default=0)
    # Hard-coded 40M default, matching the "40M" budget in the benchmark name;
    # it is NOT read from the task object.
    parser.add_argument('--num_timesteps', type=int, default=40000000)
    args = parser.parse_args(argv)

    # Choose the game to play and set the log file name.
    benchmark = gym.benchmark_spec('Atari40M')
    task = benchmark.tasks[game_to_ID[args.game]]
    log_name = args.game + "_s" + str(args.seed).zfill(3) + ".pkl"

    # Run training. Should change the seed if possible!
    # Also, the actual # of iterations run is _roughly_ num_timesteps/4.
    seed = args.seed
    env = get_env(task, seed)
    session = get_session()
    print("task = {}".format(task))
    atari_learn(env,
                session,
                num_timesteps=args.num_timesteps,
                log_file=log_name)
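# Web-page residue from where this script was copied ("Comment list",
# "Table of contents"); not part of the program.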