def logistic_loss(observed_ratings, predicted_ratings):
    """
    Logistic loss for explicit data.

    Parameters
    ----------

    observed_ratings: tensor
        Tensor containing observed ratings which
        should be +1 or -1 for this loss function.
        Must not require gradients.
    predicted_ratings: tensor
        Tensor containing rating predictions as raw logits
        (the sigmoid is applied inside the loss).

    Returns
    -------

    loss: tensor
        Zero-dimensional tensor holding the mean value of the
        loss function over all elements.
    """

    assert_no_grad(observed_ratings)

    # Convert target classes from (-1, 1) to (0, 1): clamping to [0, 1]
    # sends -1 to 0 and leaves +1 unchanged.
    targets = torch.clamp(observed_ratings, 0, 1)
    # binary_cross_entropy_with_logits needs a floating-point target;
    # cast to the predictions' dtype so integer-typed +/-1 labels
    # (e.g. a LongTensor) work instead of raising.
    targets = targets.to(dtype=predicted_ratings.dtype)

    # reduction='mean' replaces the deprecated size_average=True and
    # computes the same mean over all elements.
    return F.binary_cross_entropy_with_logits(predicted_ratings,
                                              targets,
                                              reduction='mean')
# (web-page residue from the original source: "Comment list", "Table of contents")