def softmax(x: T.Tensor, axis: int = 1) -> T.Tensor:
    """
    Softmax function on a tensor.

    Exponentiates the tensor elementwise and divides by the sum along
    `axis`. The maximum along `axis` is subtracted before exponentiating;
    this leaves the result mathematically unchanged but keeps `exp` from
    overflowing on large inputs.

    Args:
        x: A tensor.
        axis (optional): The axis along which to normalize. Defaults to 1.

    Returns:
        tensor: Softmax of the tensor along `axis`.

    """
    # NOTE(review): softmax is only correct if the matrix helpers take their
    # operands in (a, b) -> b OP a order, i.e. subtract(a, b) == b - a and
    # divide(a, b) == b / a. Confirm against the matrix module.
    # keepdims=True keeps the reduced axis so the max/sum broadcast back
    # against the full tensor.
    xreg = matrix.subtract(matrix.tmax(x, axis=axis, keepdims=True), x)
    y = torch.exp(xreg)
    return matrix.divide(matrix.tsum(y, axis=axis, keepdims=True), y)
评论列表
文章目录