util.py source code

Project: allennlp  Author: allenai
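```python
from torch.autograd import Variable
from allennlp.common.checks import ConfigurationError


def replace_masked_values(tensor: Variable, mask: Variable, replace_with: float) -> Variable:
    """
    Replaces all masked values in ``tensor`` with ``replace_with``. ``mask`` must be broadcastable
    to the same shape as ``tensor``, and holds 1 at positions to keep and 0 at masked positions.
    We require that ``tensor.dim() == mask.dim()``, as otherwise we won't know which dimensions of
    the mask to unsqueeze.
    """
    # We'll build a tensor of the same shape as `tensor`, zero out masked values, then add back in
    # the `replace_with` value.
    if tensor.dim() != mask.dim():
        raise ConfigurationError("tensor.dim() (%d) != mask.dim() (%d)" % (tensor.dim(), mask.dim()))
    # 1 at masked positions, 0 at kept positions.
    one_minus_mask = 1.0 - mask
    # `replace_with` at masked positions, 0 elsewhere. A non-finite `replace_with` (e.g. -inf)
    # would produce NaN at kept positions here, because inf * 0 is NaN.
    values_to_add = replace_with * one_minus_mask
    # Kept positions retain their original value; masked positions become `replace_with`.
    return tensor * mask + values_to_add
```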
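The multiply-and-add trick above can be checked without PyTorch. The sketch below is a hypothetical pure-Python stand-in: `replace_masked_values_list`, `masked_max` and `replace_masked_values_select` are illustrative names, not allennlp APIs. It works on flat lists with a 1.0/0.0 mask instead of tensors. It shows three things: the same blend formula, a typical use where padded positions are replaced by a large negative number before taking a max, and why a non-finite `replace_with` breaks the arithmetic form while a select-based version does not.

```python
import math
from typing import List


def replace_masked_values_list(values: List[float], mask: List[float],
                               replace_with: float) -> List[float]:
    """Hypothetical list version of the arithmetic blend; mask is 1.0 (keep) or 0.0 (masked)."""
    if len(values) != len(mask):
        raise ValueError("values and mask must have the same length")
    # Same formula as the original: v * m + replace_with * (1 - m).
    return [v * m + replace_with * (1.0 - m) for v, m in zip(values, mask)]


def replace_masked_values_select(values: List[float], mask: List[float],
                                 replace_with: float) -> List[float]:
    """Hypothetical select-based variant: picks a value instead of multiplying, so inf is safe."""
    return [v if m else replace_with for v, m in zip(values, mask)]


def masked_max(values: List[float], mask: List[float]) -> float:
    # Masked positions get a very negative finite value, so they can never win the max.
    return max(replace_masked_values_list(values, mask, -1e7))


scores = [0.5, 3.0, -1.0, 9.0]
mask = [1.0, 1.0, 1.0, 0.0]  # the last position is padding

print(replace_masked_values_list(scores, mask, -1e7))  # [0.5, 3.0, -1.0, -10000000.0]
print(masked_max(scores, mask))                        # 3.0

# With -inf, the kept position becomes 1.0 + (-inf * 0.0) = nan.
print(replace_masked_values_list([1.0, 2.0], [1.0, 0.0], -math.inf))    # [nan, -inf]
print(replace_masked_values_select([1.0, 2.0], [1.0, 0.0], -math.inf))  # [1.0, -inf]
```

Because the padded 9.0 is pushed to -1e7, the max comes from the real positions only. This is why finite sentinels such as -1e7 are the safe choice for the multiply-and-add form. In PyTorch, `Tensor.masked_fill(mask == 0, replace_with)` performs the select-style replacement on tensors directly.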