losses.py source code

python
Project: spotlight  Author: maciejkula
import torch
import torch.nn.functional as F

from spotlight.torch_utils import assert_no_grad


def logistic_loss(observed_ratings, predicted_ratings):
    """
    Logistic loss for explicit data.

    Parameters
    ----------

    observed_ratings: tensor
        Tensor containing observed ratings which
        should be +1 or -1 for this loss function.
    predicted_ratings: tensor
        Tensor containing rating predictions.

    Returns
    -------

    loss: tensor
        Scalar tensor holding the mean value of the loss function.
    """

    assert_no_grad(observed_ratings)

    # Convert target classes from (-1, 1) to (0, 1)
    observed_ratings = torch.clamp(observed_ratings, 0, 1)

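    # `size_average=True` is deprecated in current PyTorch;
    # `reduction='mean'` is the equivalent setting.
    return F.binary_cross_entropy_with_logits(predicted_ratings,
                                              observed_ratings,
                                              reduction='mean')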
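
The clamp-then-cross-entropy step can be checked by hand. The sketch below is a plain-Python illustration and is not part of spotlight: the helper names `clamp`, `bce_with_logits`, `logistic_loss_reference` and `margin_form` are hypothetical, and the sample values are made up. It assumes labels are exactly -1 or +1, and suggests that the loss computed this way matches the usual margin form, the mean of log(1 + exp(-y * p)).

```python
import math


def clamp(value, low, high):
    """Plain-Python stand-in for torch.clamp on a single number."""
    return max(low, min(high, value))


def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy on a raw score (logit)."""
    # max(x, 0) - x * t + log(1 + exp(-|x|))
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def logistic_loss_reference(observed, predicted):
    """Mean logistic loss for labels in {-1, +1}, mirroring the function above."""
    targets = [clamp(y, 0, 1) for y in observed]  # -1 -> 0, +1 -> 1
    losses = [bce_with_logits(p, t) for p, t in zip(predicted, targets)]
    return sum(losses) / len(losses)


def margin_form(observed, predicted):
    """Margin form of the same loss: mean of log(1 + exp(-y * p))."""
    terms = [math.log1p(math.exp(-y * p)) for y, p in zip(observed, predicted)]
    return sum(terms) / len(terms)


# Illustrative labels and raw scores (assumed values, not from the project).
observed = [1, -1, 1, -1]
predicted = [2.0, -1.5, -0.5, 0.3]

# Both formulations should agree up to floating-point error.
assert math.isclose(logistic_loss_reference(observed, predicted),
                    margin_form(observed, predicted))
```

A score of 0 carries no information in either direction, so each such example contributes log(2) to the mean regardless of its label.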