def replace_masked_values(tensor: Variable, mask: Variable, replace_with: float) -> Variable:
    """
    Replaces all masked values in ``tensor`` with ``replace_with``.  ``mask`` must be broadcastable
    to the same shape as ``tensor``. We require that ``tensor.dim() == mask.dim()``, as otherwise we
    won't know which dimensions of the mask to unsqueeze.

    Positions where ``mask`` equals zero are treated as masked and are overwritten; every other
    position keeps its value from ``tensor`` unchanged.  The replacement is done with a fill rather
    than with arithmetic, so non-finite inputs are safe: ``inf`` / ``nan`` entries sitting at
    masked positions are replaced cleanly, and ``replace_with`` may itself be ``float("inf")`` or
    ``float("-inf")`` without producing ``nan`` anywhere in the result.

    NOTE(review): the mask is assumed to be binary (0 = masked, 1 = keep).  Any non-zero mask
    entry is treated as "keep"; fractional masks are not blended.

    Parameters
    ----------
    tensor : the values to filter.
    mask : same number of dimensions as ``tensor``; zero marks a position to replace.
    replace_with : the scalar written into every masked position.

    Returns
    -------
    A new tensor; ``tensor`` itself is not modified.

    Raises
    ------
    ConfigurationError
        If ``tensor`` and ``mask`` do not have the same number of dimensions.
    """
    if tensor.dim() != mask.dim():
        raise ConfigurationError("tensor.dim() (%d) != mask.dim() (%d)" % (tensor.dim(), mask.dim()))
    # Fill instead of computing `tensor * mask + replace_with * (1 - mask)`: multiplying by a zero
    # mask entry turns inf/nan into nan (inf * 0 == nan), which would corrupt the output.
    return tensor.masked_fill(mask == 0, replace_with)
评论列表
文章目录