metric_ops.py source file

python
Project: lsdc  Author: febert
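# Imports assumed from the TensorFlow source tree; the excerpt itself omits them.
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops


def _mask_weights(mask=None, weights=None):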
  """Mask a given set of weights.

  Elements are included when the corresponding `mask` element is `False`, and
  excluded otherwise.

  Args:
    mask: An optional, `bool` `Tensor`.
    weights: An optional `Tensor` whose shape matches `mask` if `mask` is not
      `None`.

  Returns:
    Masked weights if `mask` and `weights` are not `None`; weights that are 1.0
    where `mask` is `False` and 0.0 where it is `True` if only `weights` is
    `None`; and otherwise `weights` unchanged.

  Raises:
    ValueError: If `weights` and `mask` are not `None` and have mismatched
      shapes.
  """
  if mask is not None:
    check_ops.assert_type(mask, dtypes.bool)
    if weights is None:
      weights = array_ops.ones_like(mask, dtype=dtypes.float32)
    weights = math_ops.cast(math_ops.logical_not(mask), weights.dtype) * weights

  return weights
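
The masking rule above can be illustrated without TensorFlow. The following is only a minimal sketch of the same semantics on flat Python lists, not the library's implementation: `mask_weights_sketch` is a hypothetical name, and the explicit length check and the `TypeError` for non-bool entries are assumptions standing in for the tensor shape and dtype checks.

```python
def mask_weights_sketch(mask=None, weights=None):
    """Plain-Python analogue of _mask_weights for flat sequences (hypothetical)."""
    if mask is None:
        # No mask given: weights pass through untouched (possibly None).
        return weights
    if not all(isinstance(m, bool) for m in mask):
        # Assumed stand-in for the dtype check on the mask tensor.
        raise TypeError("mask must contain only bool values")
    if weights is None:
        # Default to unit weights, mirroring ones_like(mask) in the original.
        weights = [1.0] * len(mask)
    if len(weights) != len(mask):
        # Assumed stand-in for the shape-mismatch error.
        raise ValueError("weights and mask must have the same length")
    # Keep the weight where mask is False; zero it where mask is True.
    return [0.0 if m else float(w) for m, w in zip(mask, weights)]


print(mask_weights_sketch([True, False, False], [0.5, 2.0, 3.0]))  # [0.0, 2.0, 3.0]
print(mask_weights_sketch([True, False]))                          # [0.0, 1.0]
print(mask_weights_sketch(None, [0.5, 2.0]))                       # [0.5, 2.0]
```

Note that a `True` mask entry means "exclude": the weight at that position becomes zero, so the element contributes nothing to a weighted metric.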