def _kl_bernoulli_bernoulli(a, b, name=None):
    """Compute the batched KL divergence KL(a || b) between two Bernoullis.

    The result is built entirely from the `logits` attribute of each
    distribution, via softplus and sigmoid, rather than from explicit
    probabilities and logarithms.

    Args:
      a: instance of a Bernoulli distribution object.
      b: instance of a Bernoulli distribution object.
      name: (optional) Name to use for created operations.
        default is "kl_bernoulli_bernoulli".

    Returns:
      Batchwise KL(a || b)
    """
    logits_a = a.logits
    logits_b = b.logits
    with ops.name_scope(name, "kl_bernoulli_bernoulli",
                        values=[logits_a, logits_b]):
        # With p = sigmoid(logits), log(p) = -softplus(-logits), so this is
        # log(p_a / p_b). (Assumes nn.softplus / math_ops.sigmoid are the
        # standard softplus and logistic functions.)
        log_ratio_one = nn.softplus(-logits_b) - nn.softplus(-logits_a)
        # Likewise log(1 - p) = -softplus(logits), so this is
        # log((1 - p_a) / (1 - p_b)).
        log_ratio_zero = nn.softplus(logits_b) - nn.softplus(logits_a)
        # Weight each log-ratio by the matching outcome probability under a:
        # sigmoid(logits_a) for outcome 1, sigmoid(-logits_a) for outcome 0.
        weighted_one = math_ops.sigmoid(logits_a) * log_ratio_one
        weighted_zero = math_ops.sigmoid(-logits_a) * log_ratio_zero
        return weighted_one + weighted_zero
bernoulli.py 文件源码
python
阅读 23
收藏 0
点赞 0
评论 0
评论列表
文章目录