imsitu_model.py source code

python

Project: verb-attributes  Author: uwnlp
import torch


def devise_train(m, x, labels, data, att_crit=None, optimizers=None):
    """
    Compute the training loss for the embedding-prediction (devise) model: a margin
    ranking loss between the predicted image embedding and the verb embeddings.
    :param m: Model we're using
    :param x: [batch_size, 3, 224, 224] Image input
    :param labels: [batch_size] variable with indices of the right verbs
    :param data: object whose attributes.embeds holds the [vocab_size, 300] embeddings
                 of all of the verbs
    :param att_crit: AttributeLoss module that computes the loss (unused in this function)
    :param optimizers: the decorator will use these to update parameters
    :return: scalar loss (L2 penalty plus the mean per-example hinge loss)
    """
    # Make embed unit normed (_normalize is a helper defined outside this excerpt)
    embed_normed = _normalize(data.attributes.embeds)
    mv_image = m(x).embed_pred
    # Similarity of every image to every verb: [batch_size, vocab_size]
    tmv_image = mv_image @ embed_normed.t()

    # Similarity to the correct verb for each example: [batch_size, 1]
    correct_contrib = torch.gather(tmv_image, 1, labels[:, None])

    # Should be fine to ignore where the correct contrib intersects because the gradient
    # wrt input is 0 (that entry is always exactly the 0.1 margin)
    losses = (0.1 + tmv_image - correct_contrib.expand_as(tmv_image)).clamp(min=0.0)
    # losses.scatter_(1, labels[:, None], 0.0)
    loss = m.l2_penalty + losses.sum(1).squeeze().mean()
    return loss
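
The hinge term above can be checked by hand on a toy example. The following is a minimal, dependency-free sketch in plain Python, not the project's code: the name `devise_hinge_loss`, the `mask_correct` flag, and the score values are made up for illustration, and the L2 penalty is left out. With `mask_correct=False` it mirrors the function as written, where the correct class always contributes the constant margin; with `mask_correct=True` it mirrors the commented-out `scatter_` line, which zeroes that entry.

```python
def devise_hinge_loss(scores, labels, margin=0.1, mask_correct=False):
    """Mean over examples of sum_j max(0, margin + s_j - s_correct).

    scores: list of rows, one similarity per class (hypothetical values)
    labels: list with the index of the correct class for each row
    """
    per_example = []
    for row, label in zip(scores, labels):
        correct = row[label]  # plays the role of torch.gather(scores, 1, labels[:, None])
        total = 0.0
        for j, s in enumerate(row):
            if mask_correct and j == label:
                continue  # same effect as zeroing the correct entry
            total += max(0.0, margin + s - correct)  # clamp(min=0.0)
        per_example.append(total)
    return sum(per_example) / len(per_example)


# Two examples, three classes; the correct classes are 0 and 1.
scores = [[0.9, 0.2, 0.85], [0.1, 0.7, 0.3]]
labels = [0, 1]

loss = devise_hinge_loss(scores, labels)                      # about 0.125
loss_masked = devise_hinge_loss(scores, labels, mask_correct=True)  # about 0.025
```

The two results differ by exactly the margin (0.1), which is a constant and therefore has no effect on the gradients. That matches the remark in the original comment that the overlap with the correct class can be ignored.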