bernoulli.py source code

python

Project: lsdc  Author: febert
def _log_prob(self, event):
    # TODO(jaana): The current sigmoid_cross_entropy_with_logits has
    # inconsistent behavior for logits = inf/-inf.
    event = ops.convert_to_tensor(event, name="event")
    event = math_ops.cast(event, self.logits.dtype)
    logits = self.logits
    # sigmoid_cross_entropy_with_logits doesn't broadcast shape,
    # so we do this here.
    # TODO(b/30637701): Check dynamic shape, and don't broadcast if the
    # dynamic shapes are the same.
    if (not event.get_shape().is_fully_defined() or
        not logits.get_shape().is_fully_defined() or
        event.get_shape() != logits.get_shape()):
      logits = array_ops.ones_like(event) * logits
      event = array_ops.ones_like(logits) * event
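    # The Bernoulli log-probability is the negative sigmoid cross-entropy.
    # Note: TensorFlow 1.0 and later require keyword arguments here
    # (labels=event, logits=logits); the positional form is the older API.
    return -nn.sigmoid_cross_entropy_with_logits(logits, event)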
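
The method above relies on the identity that a Bernoulli log-probability equals the negative sigmoid cross-entropy between the logits and the event. The following is a minimal scalar sketch of that identity in plain Python, not the TensorFlow implementation: the function names `sigmoid_cross_entropy`, `bernoulli_log_prob`, and `naive_log_prob` are hypothetical, and the stable formula `max(x, 0) - x*z + log(1 + exp(-|x|))` is assumed to match what the library op computes.

```python
import math


def sigmoid_cross_entropy(logit, label):
    """Scalar sigmoid cross-entropy in a numerically stable form.

    Computes max(x, 0) - x*z + log(1 + exp(-|x|)), which equals
    -z*log(sigmoid(x)) - (1 - z)*log(1 - sigmoid(x)).
    """
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def bernoulli_log_prob(logit, event):
    """Log-probability of `event` (0 or 1) under Bernoulli(sigmoid(logit))."""
    return -sigmoid_cross_entropy(logit, float(event))


def naive_log_prob(logit, event):
    """Direct definition, for comparison only.

    Assumption: only used for moderate logits; it overflows or hits
    log(0) when |logit| is very large.
    """
    p = 1.0 / (1.0 + math.exp(-logit))
    return event * math.log(p) + (1 - event) * math.log(1.0 - p)


if __name__ == "__main__":
    for logit in (-3.0, 0.0, 2.5):
        for event in (0, 1):
            print(logit, event,
                  bernoulli_log_prob(logit, event),
                  naive_log_prob(logit, event))
```

For moderate logits the two functions should agree to floating-point precision. The stable form stays finite for very large positive or negative logits, which is presumably why the method delegates to the cross-entropy op instead of taking the log of a sigmoid directly.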