def _log_prob(self, event):
    """Returns the log-probability of `event` under this Bernoulli.

    The result is the element-wise negative sigmoid cross-entropy between
    `event` (as labels) and `self.logits`, i.e.
    ``event * log(sigmoid(logits)) + (1 - event) * log(1 - sigmoid(logits))``.

    Args:
      event: Outcomes to score. Cast to `self.logits.dtype` before use; must
        be broadcast-compatible with `self.logits`.

    Returns:
      A tensor of log-probabilities whose shape is the broadcast of the shapes
      of `event` and `self.logits`.
    """
    # TODO(jaana): The current sigmoid_cross_entropy_with_logits has
    # inconsistent behavior for logits = inf/-inf.
    event = math_ops.cast(event, self.logits.dtype)
    logits = self.logits

    def _broadcast(logits, event):
        # Multiplying by ones_like of the other operand broadcasts each tensor
        # to the common shape without changing its values.
        return (array_ops.ones_like(event) * logits,
                array_ops.ones_like(logits) * event)

    # sigmoid_cross_entropy_with_logits doesn't broadcast shape, so we do it
    # here. Skip the work only when the static shapes prove it unnecessary.
    # Otherwise broadcast unconditionally: if the runtime shapes already
    # match, multiplying by ones is a numerical no-op, so this yields the same
    # values as branching on the dynamic shape, but without a control-flow op
    # and with the broadcast shape visible to static shape inference.
    shapes_statically_equal = (
        event.get_shape().is_fully_defined() and
        logits.get_shape().is_fully_defined() and
        event.get_shape() == logits.get_shape())
    if not shapes_statically_equal:
        logits, event = _broadcast(logits, event)
    return -nn.sigmoid_cross_entropy_with_logits(labels=event, logits=logits)
bernoulli.py 文件源码
python
阅读 23
收藏 0
点赞 0
评论 0
评论列表
文章目录