def mellowmax(values, omega=1., axis=1):
    """Mellowmax function.

    This is a kind of softmax function that is, unlike the Boltzmann softmax,
    a non-expansion.

    It is computed as ``(logsumexp(omega * values) - log(n)) / omega`` over
    the reduced axes, i.e. ``log(mean(exp(omega * values))) / omega``, where
    ``n`` is the number of elements being reduced.

    See: http://arxiv.org/abs/1612.05628

    Args:
        values (Variable or ndarray):
            Input values. Mellowmax is taken along ``axis`` (the second axis
            by default).
        omega (float):
            Parameter of mellowmax. It is used as a divisor, so it must be
            nonzero.
        axis (int, tuple of int, or None):
            Axis or axes along which mellowmax is taken. ``None`` reduces
            over all elements.

    Returns:
        outputs (Variable)
    """
    # n is the count of elements averaged over; it must match exactly the
    # axes that F.logsumexp reduces, for every form of `axis` it accepts.
    if axis is None:
        n = values.size
    elif isinstance(axis, tuple):
        n = 1
        for ax in axis:
            n *= values.shape[ax]
    else:
        n = values.shape[axis]
    return (F.logsumexp(omega * values, axis=axis) - np.log(n)) / omega
评论列表
文章目录