categorical.py source code


Project: lsdc  Author: febert
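from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops


def _kl_categorical_categorical(a, b, name=None):
  """Calculate the batched KL divergence KL(a || b) with a and b Categorical.

  Args:
    a: instance of a Categorical distribution object.
    b: instance of a Categorical distribution object.
    name: (optional) Name to use for created operations.
      Default is "kl_categorical_categorical".

  Returns:
    Batchwise KL(a || b).
  """
  with ops.name_scope(
      name, "kl_categorical_categorical", [a.logits, b.logits]):
    # KL(a || b) = sum_k p_k * (log p_k - log q_k), where p = softmax(a.logits)
    # and q = softmax(b.logits); log_softmax keeps the log-ratio stable.
    log_ratio = nn_ops.log_softmax(a.logits) - nn_ops.log_softmax(b.logits)
    # Sum over the last (class) axis: one KL value per batch entry.
    # `axis` replaces the deprecated `reduction_indices` argument.
    return math_ops.reduce_sum(nn_ops.softmax(a.logits) * log_ratio, axis=[-1])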
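To see what the TensorFlow function computes without a TensorFlow install, the same formula, KL(a || b) = sum_k p_k * (log p_k - log q_k), can be reproduced in plain Python. This is a minimal sketch under stated assumptions: each batch entry's logits are a plain list of floats, and `log_softmax`, `kl_categorical` and `batched_kl` are hypothetical helpers written for illustration, not part of the lsdc project or the TensorFlow API. Like the original, it works in log space, so the log-ratio stays finite even for very large logits.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Hypothetical helper: subtract the max logit before exponentiating
    # (log-sum-exp trick) so large logits do not overflow math.exp.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def kl_categorical(a_logits: Sequence[float], b_logits: Sequence[float]) -> float:
    """Hypothetical helper: KL(a || b) for one pair of logit vectors."""
    if len(a_logits) != len(b_logits):
        raise ValueError("logit vectors must have the same number of classes")
    log_p = log_softmax(a_logits)
    log_q = log_softmax(b_logits)
    # sum_k p_k * (log p_k - log q_k), recovering p_k as exp(log p_k)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def batched_kl(a_batch: Sequence[Sequence[float]],
               b_batch: Sequence[Sequence[float]]) -> List[float]:
    # Hypothetical helper mirroring reduce_sum(..., axis=[-1]):
    # one KL value per row of the batch.
    if len(a_batch) != len(b_batch):
        raise ValueError("batches must have the same size")
    return [kl_categorical(a, b) for a, b in zip(a_batch, b_batch)]
```

For identical logits the result is exactly 0.0, and `kl_categorical([1000.0, 0.0], [0.0, 1000.0])` returns 1000.0 instead of overflowing, because the max logit is subtracted before exponentiating. KL is not symmetric, so swapping `a` and `b` generally changes the result.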