def _mask_weights(mask=None, weights=None):
  """Mask a given set of weights.

  Elements are included when the corresponding `mask` element is `False`, and
  excluded (their weight is multiplied by 0) otherwise.

  Args:
    mask: An optional, `bool` `Tensor`. `True` entries are masked out.
    weights: An optional `Tensor` whose shape matches `mask` if `mask` is not
      `None`.

  Returns:
    - If `mask` is `None`: `weights`, returned unchanged (possibly `None`).
    - If `mask` is given and `weights` is `None`: a `float32` `Tensor` shaped
      like `mask`, holding 1.0 where `mask` is `False` and 0.0 where it is
      `True`. This is the logical inverse of `mask`.
    - Otherwise: `weights` with every element where `mask` is `True` set to
      zero, in the dtype of `weights`.

  Raises:
    TypeError: If `mask` is not `None` and its dtype is not `bool`. This is
      raised by `check_ops.assert_type`.

  Note:
    No explicit shape check is performed. `mask` and `weights` are combined
    with `*`, so shape compatibility is governed by that op's broadcasting
    rules. Broadcast-compatible but unequal shapes are accepted silently.
    Incompatible shapes are expected to fail inside the multiply (NOTE(review):
    typically a `ValueError` during graph-mode shape inference; confirm the
    error type under eager execution).
  """
  if mask is None:
    # Nothing to mask: hand `weights` back untouched (it may itself be None).
    return weights

  # Dtype validation only; the value returned by assert_type is not used here.
  check_ops.assert_type(mask, dtypes.bool)

  if weights is None:
    # No weights supplied: start from all-ones float32 weights shaped like mask.
    weights = array_ops.ones_like(mask, dtype=dtypes.float32)

  # logical_not(mask) is True for kept elements. Casting it to the weights'
  # dtype yields a 1/0 multiplier that zeroes the masked-out elements.
  return math_ops.cast(math_ops.logical_not(mask), weights.dtype) * weights
评论列表
文章目录