def test_value(self):
    """Compare `OnehotCategorical.log_prob` / `.prob` with NumPy references.

    For several logits/label combinations (including ones whose shapes
    differ and must broadcast against each other), the values produced by
    the distribution are checked with `assertAllClose` against the same
    quantities computed directly in NumPy.
    """
    def _np_one_hot(labels, depth):
        # Row k of the identity matrix is the one-hot vector for class k,
        # so indexing with `labels` yields shape labels.shape + (depth,).
        return np.eye(depth)[labels]

    def _check(logits, given):
        logits = np.array(logits, np.float32)
        given = np.array(given, np.int32)
        n_classes = logits.shape[-1]

        # NumPy reference: logits shifted by misc.logsumexp over the last
        # axis (presumably the usual log-sum-exp, i.e. a log-softmax --
        # confirm against `misc`), then masked by the one-hot labels.
        log_probs = logits - misc.logsumexp(logits, axis=-1, keepdims=True)
        mask = _np_one_hot(given, n_classes)
        expected_log_p = np.sum(mask * log_probs, -1)
        expected_p = np.sum(mask * np.exp(log_probs), -1)

        dist = OnehotCategorical(logits)

        log_p = dist.log_prob(tf.one_hot(given, n_classes, dtype=tf.int32))
        self.assertAllClose(log_p.eval(), expected_log_p)

        p = dist.prob(tf.one_hot(given, n_classes, dtype=tf.int32))
        self.assertAllClose(p.eval(), expected_p)

    # (logits, given class indices) pairs exercised by the test.
    cases = (
        ([0.], [0, 0, 0]),
        ([-50., -10., -50.], [0, 1, 2, 1]),
        ([0., 4.], [[0, 1], [0, 1]]),
        ([[2., 3., 1.], [5., 7., 4.]], np.ones([3, 1, 1], dtype=np.int32)),
    )

    # `.eval()` needs a default session, so every check runs inside it.
    with self.test_session(use_gpu=True):
        for logits, given in cases:
            _check(logits, given)
评论列表
文章目录