bernoulli.py source code

python

Project: DeepLearning_VirtualReality_BigData_Project  Author: rashmitripathi
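# Imports this excerpt relies on (TensorFlow internal modules).
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn


def _kl_bernoulli_bernoulli(a, b, name=None):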
  """Calculate the batched KL divergence KL(a || b) with a and b Bernoulli.

  Args:
    a: instance of a Bernoulli distribution object.
    b: instance of a Bernoulli distribution object.
    name: (optional) Name to use for created operations.
      Default is "kl_bernoulli_bernoulli".

  Returns:
    Batchwise KL(a || b)
  """
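  with ops.name_scope(name, "kl_bernoulli_bernoulli",
                      values=[a.logits, b.logits]):
    # log(p_a) - log(p_b), using log(sigmoid(x)) = -softplus(-x).
    delta_probs0 = nn.softplus(-b.logits) - nn.softplus(-a.logits)
    # log(1 - p_a) - log(1 - p_b), using log(1 - sigmoid(x)) = -softplus(x).
    delta_probs1 = nn.softplus(b.logits) - nn.softplus(a.logits)
    # KL(a || b) = p_a * delta_probs0 + (1 - p_a) * delta_probs1.
    return (math_ops.sigmoid(a.logits) * delta_probs0
            + math_ops.sigmoid(-a.logits) * delta_probs1)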
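
The logits-based expression above can be checked against the textbook definition KL(a || b) = p_a log(p_a / p_b) + (1 - p_a) log((1 - p_a) / (1 - p_b)). The following is a minimal NumPy sketch of that check, not part of the original file. The helper names `softplus`, `sigmoid`, `kl_bernoulli_logits`, and `kl_bernoulli_probs` are hypothetical and introduced only for illustration, and plain arrays of logits stand in for the Bernoulli distribution objects.

```python
import numpy as np


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return np.logaddexp(0.0, x)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def kl_bernoulli_logits(a_logits, b_logits):
    """KL(a || b) computed from logits, mirroring the function above."""
    a_logits = np.asarray(a_logits, dtype=np.float64)
    b_logits = np.asarray(b_logits, dtype=np.float64)
    # log(p_a) - log(p_b)
    delta_probs0 = softplus(-b_logits) - softplus(-a_logits)
    # log(1 - p_a) - log(1 - p_b)
    delta_probs1 = softplus(b_logits) - softplus(a_logits)
    return sigmoid(a_logits) * delta_probs0 + sigmoid(-a_logits) * delta_probs1


def kl_bernoulli_probs(p, q):
    """KL(a || b) from the textbook definition, using probabilities."""
    p = np.asarray(p, dtype=np.float64)
    q = np.asarray(q, dtype=np.float64)
    return p * np.log(p / q) + (1.0 - p) * np.log((1.0 - p) / (1.0 - q))


if __name__ == "__main__":
    a = np.array([0.0, 1.5, -2.0])   # example logits for distribution a
    b = np.array([0.5, -1.0, -2.0])  # example logits for distribution b
    from_logits = kl_bernoulli_logits(a, b)
    from_probs = kl_bernoulli_probs(sigmoid(a), sigmoid(b))
    print(from_logits)
    print(np.allclose(from_logits, from_probs))
```

Working in logit space avoids taking the logarithm of a probability that has underflowed to 0 or saturated to 1, which is presumably why the original is written with softplus rather than with explicit probabilities.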