def test_value(self):
    """Check ``Concrete.log_prob`` and ``Concrete.prob`` against NumPy.

    For ``n`` categories, temperature ``t`` and unnormalized ``logits``, the
    reference log-density of a point ``x`` on the simplex is

        log p(x) = log Gamma(n) + (n - 1) * log(t)
                   + sum_k (logits_k - (t + 1) * log(x_k))
                   - n * logsumexp_k(logits_k - t * log(x_k))

    and ``prob`` must equal ``exp(log_prob)``. Both a single sample and a
    batch of two samples (with per-row logits) are checked.
    """
    with self.test_session(use_gpu=True):
        def _test_value(given, temperature, logits):
            given = np.array(given, np.float32)
            logits = np.array(logits, np.float32)
            n = logits.shape[-1]
            t = temperature
            log_given = np.log(given)
            # Per-category terms logits_k - t * log(x_k) inside the
            # normalizer.
            scores = logits - t * log_given
            # Stable log-sum-exp: shift by the per-row max so exp() cannot
            # overflow for small temperatures or large logits (the naive
            # log(sum(exp(.))) form would produce inf in float32).
            max_score = scores.max(axis=-1, keepdims=True)
            log_norm = np.squeeze(max_score, axis=-1) + \
                np.log(np.exp(scores - max_score).sum(axis=-1))
            target_log_p = special.gammaln(n) + (n - 1) * np.log(t) + \
                (logits - (t + 1) * log_given).sum(axis=-1) - \
                n * log_norm
            con = Concrete(temperature, logits=logits)
            log_p = con.log_prob(given)
            self.assertAllClose(log_p.eval(), target_log_p)
            p = con.prob(given)
            self.assertAllClose(p.eval(), np.exp(target_log_p))

        # Single sample over 3 categories at a low temperature.
        _test_value([0.25, 0.25, 0.5],
                    0.1,
                    [1., 1., 1.2])
        # Batch of two samples with per-row logits and a shared temperature.
        _test_value([[0.25, 0.25, 0.5],
                     [0.1, 0.5, 0.4]],
                    0.5,
                    [[1., 1., 1.], [.5, .5, .4]])
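# NOTE(review): stray web-page navigation labels from the page this code was
# copied from, kept as a comment so the file parses:
# 评论列表 ("Comment list"), 文章目录 ("Table of contents")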