losses.py source code

python
Project: spotlight  Author: maciejkula
import torch


def hinge_loss(positive_predictions, negative_predictions, mask=None):
    """
    Hinge pairwise loss function.

    Parameters
    ----------

    positive_predictions: tensor
        Tensor containing predictions for known positive items.
    negative_predictions: tensor
        Tensor containing predictions for sampled negative items.
    mask: tensor, optional
        A binary tensor used to zero the loss from some entries
        of the loss tensor.

    Returns
    -------

    loss: tensor
        Scalar tensor holding the mean value of the loss function
        (averaged over unmasked entries when a mask is given).
    """

    # Per-pair hinge term: max(0, negative - positive + 1.0).
    loss = torch.clamp(negative_predictions -
                       positive_predictions +
                       1.0, min=0.0)

    if mask is not None:
        mask = mask.float()
        loss = loss * mask
        return loss.sum() / mask.sum()

    return loss.mean()
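
To make the arithmetic concrete, here is a minimal plain-Python sketch that mirrors the function above on lists instead of tensors. The helper name `hinge_loss_reference` and the sample scores are assumptions made for illustration; they are not part of spotlight.

```python
def hinge_loss_reference(positive, negative, mask=None):
    """Plain-Python mirror of hinge_loss (hypothetical helper, for hand-checking)."""
    # Per-pair hinge term: max(0, negative - positive + 1.0).
    losses = [max(n - p + 1.0, 0.0) for p, n in zip(positive, negative)]

    if mask is not None:
        # Zero out masked entries and average over the kept ones only.
        weights = [float(m) for m in mask]
        masked_total = sum(l * w for l, w in zip(losses, weights))
        return masked_total / sum(weights)

    return sum(losses) / len(losses)


# Made-up scores: per-pair terms are [0.0, 1.5, 3.0].
positive = [2.0, 0.5, 1.0]
negative = [0.5, 1.0, 3.0]

unmasked = hinge_loss_reference(positive, negative)               # 4.5 / 3
masked = hinge_loss_reference(positive, negative, mask=[1, 1, 0])  # 1.5 / 2
print(unmasked, masked)
```

The first pair contributes nothing because the positive item already outscores the negative one by more than the margin of 1.0. With the mask, the third pair is dropped from both the numerator and the denominator, so the mean is taken over two entries rather than three. An all-zero mask leaves the denominator at zero, so callers would likely want at least one unmasked entry per batch.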