def _kl_categorical_categorical(a, b, name=None):
  """Batched KL divergence KL(a || b) between two Categorical distributions.

  Evaluates sum_k p_k * (log p_k - log q_k) with p = softmax(a.logits) and
  q = softmax(b.logits). The sum runs over the last axis of the logits, so
  one KL value is produced per batch entry. Assumes a.logits and b.logits
  have matching category counts (not checked here).

  Args:
    a: Categorical distribution object; its `logits` define p.
    b: Categorical distribution object; its `logits` define q.
    name: (optional) name scope for the created operations. When omitted,
      "kl_categorical_categorical" is used.

  Returns:
    Tensor holding KL(a || b) for each batch entry.
  """
  with ops.name_scope(
      name, "kl_categorical_categorical", [a.logits, b.logits]):
    # Build the ops in a fixed order: softmax(a), log_softmax(a),
    # log_softmax(b). Graph op names are auto-assigned in creation order.
    probs_a = nn_ops.softmax(a.logits)
    # Use log_softmax rather than log(softmax(...)) so the log-probabilities
    # stay finite even where a probability underflows to 0.
    log_probs_a = nn_ops.log_softmax(a.logits)
    log_probs_b = nn_ops.log_softmax(b.logits)
    log_ratio = log_probs_a - log_probs_b
    weighted = probs_a * log_ratio
    # Reduce over the last (category) axis only, keeping batch dimensions.
    return math_ops.reduce_sum(weighted, reduction_indices=[-1])
评论列表
文章目录