util.py source code


Project: allennlp  Author: allenai
from typing import Callable, Optional

import torch


def _last_dimension_applicator(function_to_apply: Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor],
                               tensor: torch.Tensor,
                               mask: Optional[torch.Tensor] = None):
    """
    Takes a tensor with 3 or more dimensions and applies a function over the last dimension.  We
    assume the tensor has shape ``(batch_size, ..., sequence_length)`` and that the mask (if given)
    has shape ``(batch_size, sequence_length)``.  We first unsqueeze and expand the mask so that it
    has the same shape as the tensor, then flatten them both to be 2D, pass them through
    the function and put the tensor back in its original shape.
    """
    tensor_shape = tensor.size()
    reshaped_tensor = tensor.view(-1, tensor.size()[-1])
    if mask is not None:
        while mask.dim() < tensor.dim():
            mask = mask.unsqueeze(1)
        mask = mask.expand_as(tensor).contiguous().float()
        mask = mask.view(-1, mask.size()[-1])
    reshaped_result = function_to_apply(reshaped_tensor, mask)
    return reshaped_result.view(*tensor_shape)