import torch.nn.functional as F


def softmax(x, dim=-1):
"""
TODO: change to use the default pyTorch implementation when available
Source: https://discuss.pytorch.org/t/why-softmax-function-cant-specify-the-dimension-to-operate/2637
:param x: tensor
:param dim: Dimension to apply the softmax function to. The elements of the tensor in this
dimension must sum to 1.
:return: tensor having the same dimension as `x` rescaled along dim
"""
    input_size = x.size()
    # Move the target dimension to the last axis so the whole tensor can be
    # flattened to 2-D, where F.softmax operates on the last dimension.
    trans_input = x.transpose(dim, len(input_size) - 1)
    trans_size = trans_input.size()
    input_2d = trans_input.contiguous().view(-1, trans_size[-1])
    try:
        soft_max_2d = F.softmax(input_2d, 1)
    except TypeError:
        # Support the older PyTorch 0.2 release, whose F.softmax takes no
        # dimension argument.
        soft_max_2d = F.softmax(input_2d)
    # Restore the original shape and dimension order.
    soft_max_nd = soft_max_2d.view(*trans_size)
    return soft_max_nd.transpose(dim, len(input_size) - 1)
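

# A minimal usage sketch (illustrative only; the shape and values below are
# arbitrary). It assumes a reasonably recent PyTorch where plain tensors can
# be passed directly to F.softmax.
if __name__ == "__main__":
    import torch

    x = torch.randn(2, 3, 4)
    y = softmax(x, dim=1)
    print(y.shape)       # torch.Size([2, 3, 4]), same shape as the input
    print(y.sum(dim=1))  # every slice along dim=1 now sums to ~1.0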