misc.py source code

python

Project: verb-attributes  Author: uwnlp
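import torch
import torch.nn.functional as F


def _normalize(x):
    # Assumption: misc.py defines _normalize elsewhere (not shown); here it is taken
    # to L2-normalize each row so that dot products become cosine similarities.
    return F.normalize(x, p=2, dim=1)


def cosine_ranking_loss(input_data, ctx, margin=0.1):
    """
    Hinge ranking loss on cosine similarity: each prediction should be closer
    to its own ground truth than to a randomly shuffled ground truth by at
    least `margin`.

    :param input_data: [batch_size, 300] tensor of predictions
    :param ctx: [batch_size, 300] tensor of ground truths
    :param margin: minimum required gap between the correct and the
        shuffled (incorrect) cosine similarity
    :return: (cost, correct_contrib, incorrect_contrib), each a
        [batch_size] tensor
    """
    normed = _normalize(input_data)
    ctx_normed = _normalize(ctx)

    # A random permutation of the batch supplies one negative per example;
    # the indices are created on the same device as ctx.
    shuff_inds = torch.randperm(normed.size(0), device=ctx.device)
    shuff = ctx_normed[shuff_inds]

    # Row-wise dot products of unit vectors = cosine similarities, shape [batch_size].
    correct_contrib = torch.sum(normed * ctx_normed, 1)
    incorrect_contrib = torch.sum(normed * shuff, 1)

    # Alternative (disabled): score against every ground truth at once with the
    # full similarity matrix torch.mm(normed, ctx_normed.t()), whose diagonal
    # holds the correct similarities.

    # Hinge: penalize when the incorrect similarity comes within `margin` of the correct one.
    cost = (margin + incorrect_contrib - correct_contrib).clamp(min=0)

    return cost, correct_contrib, incorrect_contrib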
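To trace the hinge computation by hand without PyTorch, the dependency-free sketch below re-implements the same per-example cost, max(0, margin + cos(p_i, g_perm(i)) - cos(p_i, g_i)). It is an illustration under stated assumptions, not the project's code: the names `cosine_ranking_loss_sketch`, `l2_normalize`, `dot` and the optional `perm` argument (which makes the shuffled negatives deterministic) are hypothetical.

```python
import math
import random
from typing import List, Optional, Sequence, Tuple

Vector = Sequence[float]


def l2_normalize(v: Vector, eps: float = 1e-12) -> List[float]:
    """Scale a vector to unit L2 norm; eps guards against division by zero."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / max(norm, eps) for x in v]


def dot(a: Vector, b: Vector) -> float:
    """Plain dot product; on unit vectors this equals cosine similarity."""
    return sum(x * y for x, y in zip(a, b))


def cosine_ranking_loss_sketch(
    preds: Sequence[Vector],
    truths: Sequence[Vector],
    margin: float = 0.1,
    perm: Optional[Sequence[int]] = None,
) -> Tuple[List[float], List[float], List[float]]:
    """Hypothetical pure-Python analogue of cosine_ranking_loss.

    Returns (cost, correct, incorrect) lists, one entry per example.
    """
    n = len(preds)
    if perm is None:
        # Random permutation of the batch, playing the role of torch.randperm.
        perm = random.sample(range(n), n)
    p = [l2_normalize(v) for v in preds]
    g = [l2_normalize(v) for v in truths]
    correct = [dot(p[i], g[i]) for i in range(n)]
    incorrect = [dot(p[i], g[perm[i]]) for i in range(n)]
    cost = [max(0.0, margin + inc - cor) for cor, inc in zip(correct, incorrect)]
    return cost, correct, incorrect


# Example: the second prediction points at the first example's ground truth.
cost, correct, incorrect = cosine_ranking_loss_sketch(
    [[1.0, 0.0], [1.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], margin=0.1, perm=[1, 0]
)
```

In the example, only the second prediction incurs a cost (0.1 + 1 - 0 = 1.1), because it is more similar to the shuffled ground truth than to its own. A random permutation can also map an example to itself; for that row the correct and incorrect similarities coincide and the cost is exactly `margin`.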