def test_sarsa():
    """Train a SARSA agent on ``TwoRoundDeterministicRewardEnv`` and check its test reward.

    The agent is fitted for 20000 steps with an epsilon-greedy policy
    (eps=0.1), then evaluated over 20 episodes with epsilon set to 0. The
    test passes when the mean episode reward over those episodes equals 3.
    """
    env = TwoRoundDeterministicRewardEnv()

    # Seed every random source used here (NumPy, the environment, and the
    # stdlib `random` module) so the run is repeatable.
    np.random.seed(123)
    env.seed(123)
    random.seed(123)

    nb_actions = env.action_space.n

    # Build a very simple model: a scalar observation feeds one 16-unit ReLU
    # hidden layer, followed by a linear output with one value per action.
    model = Sequential()
    model.add(Dense(16, input_shape=(1,)))
    model.add(Activation('relu'))
    model.add(Dense(nb_actions, activation='linear'))

    # Keep a handle on the policy so epsilon can be changed after training.
    policy = EpsGreedyQPolicy(eps=.1)
    sarsa = SARSAAgent(model=model, nb_actions=nb_actions, nb_steps_warmup=50, policy=policy)
    sarsa.compile(Adam(lr=1e-3))

    sarsa.fit(env, nb_steps=20000, visualize=False, verbose=0)

    # Evaluate with epsilon at zero -- presumably this makes action selection
    # purely greedy (confirm against EpsGreedyQPolicy).
    policy.eps = 0.
    h = sarsa.test(env, nb_episodes=20, visualize=False)
    assert_allclose(np.mean(h.history['episode_reward']), 3.)
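# [web page residue, not part of the test] Comment list
# [web page residue, not part of the test] Table of contents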