def hinge_loss(positive_predictions, negative_predictions, mask=None, margin=1.0):
    """
    Hinge pairwise loss function.

    Each (positive, negative) pair contributes
    ``max(0, negative - positive + margin)``, i.e. a penalty whenever the
    negative item's prediction is not at least ``margin`` below the
    positive item's prediction.

    Parameters
    ----------

    positive_predictions: tensor
        Tensor containing predictions for known positive items.
    negative_predictions: tensor
        Tensor containing predictions for sampled negative items.
    mask: tensor, optional
        A binary tensor used to zero the loss from some entries
        of the loss tensor.
    margin: float, optional
        Required separation between positive and negative predictions.
        Defaults to 1.0 (the value this function has always used).

    Returns
    -------

    loss: tensor
        Zero-dimensional tensor holding the mean loss. When ``mask`` is
        given, the mean is taken over the unmasked entries only; an
        all-zero mask yields 0 instead of NaN.
    """

    loss = torch.clamp(
        negative_predictions - positive_predictions + margin, min=0.0
    )

    if mask is not None:
        # Accept bool/int masks as well as float ones.
        mask = mask.float()
        loss = loss * mask
        # Clamp the denominator so an all-zero mask gives 0 / 1 = 0 rather
        # than 0 / 0 = NaN, which would poison every subsequent gradient
        # update. For a binary mask the sum is an integer count, so this
        # changes nothing except the empty case.
        return loss.sum() / mask.sum().clamp(min=1.0)

    return loss.mean()
评论列表
文章目录