util.py source code

Project: pyro  Author: uber
import torch.nn.functional as F


def softmax(x, dim=-1):
    """
    TODO: switch to the default PyTorch implementation when available.
    Source: https://discuss.pytorch.org/t/why-softmax-function-cant-specify-the-dimension-to-operate/2637
    :param x: tensor
    :param dim: Dimension along which to apply the softmax. The elements of the result
        along this dimension sum to 1.
    :return: tensor with the same shape as `x`, normalized along `dim`
    """
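    input_size = x.size()
    last_dim = len(input_size) - 1

    # Swap the target dimension with the last one.
    trans_input = x.transpose(dim, last_dim)
    trans_size = trans_input.size()

    # Flatten all other dimensions so each row holds one slice along `dim`.
    input_2d = trans_input.contiguous().view(-1, trans_size[-1])

    try:
        # Apply softmax across each row (dimension 1 of the 2-D view).
        soft_max_2d = F.softmax(input_2d, 1)
    except TypeError:
        # Support older PyTorch 0.2 releases, whose F.softmax takes no dim argument.
        soft_max_2d = F.softmax(input_2d)

    # Restore the transposed shape, then swap the dimension back into place.
    soft_max_nd = soft_max_2d.view(*trans_size)
    return soft_max_nd.transpose(dim, last_dim)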
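The swap, flatten, normalize, and restore strategy used by `softmax` is not specific to PyTorch. The following is a minimal sketch of the same steps in NumPy, assuming NumPy is available; `softmax_nd` is a hypothetical name and is not part of pyro. Subtracting the row maximum before exponentiating is a common numerical-stability trick added here; it is not in the original code.

```python
import numpy as np


def softmax_nd(x: np.ndarray, dim: int = -1) -> np.ndarray:
    """Hypothetical NumPy sketch: softmax along `dim` via swap/flatten/restore."""
    last = x.ndim - 1
    # Swap the target dimension with the last one (mirrors x.transpose(dim, last)).
    swapped = np.swapaxes(x, dim, last)
    swapped_shape = swapped.shape
    # Flatten every other dimension into rows: shape (N, size of `dim`).
    rows = np.ascontiguousarray(swapped).reshape(-1, swapped_shape[-1])
    # Row-wise softmax; subtracting the row max keeps exp() from overflowing.
    shifted = rows - rows.max(axis=1, keepdims=True)
    exps = np.exp(shifted)
    probs = exps / exps.sum(axis=1, keepdims=True)
    # Restore the swapped shape, then swap the dimension back into place.
    return np.swapaxes(probs.reshape(swapped_shape), dim, last)


x = np.arange(24, dtype=float).reshape(2, 3, 4)
out = softmax_nd(x, dim=1)
print(out.shape, np.allclose(out.sum(axis=1), 1.0))
```

The result keeps the input's shape, and the values along the chosen dimension sum to 1, which is the same contract the docstring of `softmax` describes.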