def test_compute_advantage(self):
    """Check ``self.qout.compute_advantage`` against directly computed values.

    A random action is drawn per batch row. The greedy action of each row is
    the argmax of ``self.q_values`` along axis 1. For every row:

    * if the sampled action is the greedy one, the returned advantage must
      be (almost) zero;
    * otherwise it must be strictly negative and (almost) equal to
      ``q_values[b, sampled] - q_values[b, greedy]``.

    The returned object must also be a ``chainer.Variable``.
    """
    sample_actions = np.random.randint(self.action_size,
                                       size=self.batch_size)
    greedy_actions = self.q_values.argmax(axis=1)
    ret = self.qout.compute_advantage(sample_actions)
    self.assertIsInstance(ret, chainer.Variable)

    advantages = ret.data
    for b in range(self.batch_size):
        taken = sample_actions[b]
        best = greedy_actions[b]

        if taken == best:
            # Advantage of the greedy action relative to itself is zero.
            self.assertAlmostEqual(advantages[b], 0)
            continue

        # Measured against the greedy (optimal) action, any other action's
        # advantage must be strictly negative.
        # NOTE(review): this assumes no two actions in a row share the max
        # Q-value; a tie would give an advantage of exactly 0 here and
        # fail assertLess.
        self.assertLess(advantages[b], 0)
        expected = self.q_values[b, taken] - self.q_values[b, best]
        self.assertAlmostEqual(advantages[b], expected)
# Scraped web-page residue, not code: "评论列表" ("Comment list"),
# "文章目录" ("Table of contents").