def listpl(x, t, α=15.0):
    """
    The ListPL loss, a stochastic variant of ListMLE that in expectation
    approximates the true ListNet loss.

    A permutation is sampled from a Plackett-Luce distribution over the
    target labels ``t``. The ListMLE negative log-likelihood of the scores
    ``x`` is then evaluated under that sampled ordering.

    :param x: The activation of the previous layer (the predicted scores).
        Presumably a 1-D array/Variable of per-document scores for a single
        query. TODO confirm against callers.
    :param t: The target labels, forwarded to ``_pl_sample``.
    :param α: The smoothing factor, forwarded unchanged to ``_pl_sample``
        (default 15.0).
    :return: The loss, i.e. the sum over positions of
        ``logcumsumexp(x_perm) - x_perm``.
    """
    # Sample permutation from PL(t).
    # NOTE(review): assumes _pl_sample returns an index array describing a
    # permutation of the items in t. Verify against its definition.
    index = _pl_sample(t, α)

    # Reorder the scores so that position i holds the item ranked i-th
    # in the sampled permutation.
    x = x[index]

    # Compute MLE loss.
    # NOTE(review): ListMLE needs, at each position i, log(sum_{j >= i} exp(x_j)).
    # That makes `final - x` the per-position negative log-probability. This
    # only holds if logcumsumexp accumulates from the end of the array
    # (a suffix sum). Confirm its direction.
    final = logcumsumexp(x)
    return F.sum(final - x)
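# NOTE(review): two stray lines were removed here. They were web-page navigation
# labels ("Comment list" / "Table of contents") left over from scraping, not code.
# As bare names they would raise NameError when this module is imported.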