def _safe_scalar_div(numerator, denominator, name):
"""Divides two values, returning 0 if the denominator is 0.
Args:
numerator: A scalar `float64` `Tensor`.
denominator: A scalar `float64` `Tensor`.
name: Name for the returned op.
Returns:
0 if `denominator` == 0, else `numerator` / `denominator`
"""
numerator.get_shape().with_rank_at_most(1)
denominator.get_shape().with_rank_at_most(1)
return control_flow_ops.cond(
math_ops.equal(
array_ops.constant(0.0, dtype=dtypes.float64), denominator),
lambda: array_ops.constant(0.0, dtype=dtypes.float64),
lambda: math_ops.div(numerator, denominator),
name=name)
评论列表
文章目录