losses.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:spotlight 作者: maciejkula 项目源码 文件源码
def poisson_loss(observed_ratings, predicted_ratings):
    """
    Poisson loss.

    Parameters
    ----------

    observed_ratings: tensor
        Tensor containing observed ratings.
    predicted_ratings: tensor
        Tensor containing rating predictions.

    Returns
    -------

    loss, float
        The mean value of the loss function.
    """

    assert_no_grad(observed_ratings)

    return (predicted_ratings - observed_ratings * torch.log(predicted_ratings)).mean()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号