def devise_train(m, x, labels, data, att_crit=None, optimizers=None, margin=0.1):
    """
    One training step for the DeViSE-style visual-semantic embedding model.

    Computes a ranking hinge loss: the image embedding's similarity to the
    correct verb embedding must exceed its similarity to every other verb
    embedding by at least ``margin``.

    :param m: Model we're using; ``m(x)`` must expose ``.embed_pred`` and the
        model must carry an ``.l2_penalty`` term.
    :param x: [batch_size, 3, 224, 224] Image input
    :param labels: [batch_size] variable with indices of the right verbs
    :param data: dataset object; ``data.attributes.embeds`` holds the
        embeddings of all verbs (presumably [vocab_size, 300] — confirm
        against the data loader).
    :param att_crit: AttributeLoss module that computes the loss (unused here;
        kept for a uniform trainer signature)
    :param optimizers: the decorator will use these to update parameters
    :param margin: hinge margin between correct and incorrect similarities
        (default 0.1, matching the original implementation)
    :return: scalar loss = L2 penalty + mean ranking hinge loss
    """
    # Make embed unit normed
    embed_normed = _normalize(data.attributes.embeds)
    # [batch_size, embed_dim] predicted embedding per image
    mv_image = m(x).embed_pred
    # [batch_size, vocab_size] similarity of each image to every verb
    tmv_image = mv_image @ embed_normed.t()
    # [batch_size, 1] similarity at each row's ground-truth label
    correct_contrib = torch.gather(tmv_image, 1, labels[:, None])
    # Hinge loss per (image, verb) pair. At the correct label the term is
    # exactly (margin + s - s) = margin: a constant whose gradient wrt the
    # input is 0, so there is no need to zero it out with scatter_ — it only
    # shifts the loss value, not the gradients.
    losses = (margin + tmv_image - correct_contrib.expand_as(tmv_image)).clamp(min=0.0)
    loss = m.l2_penalty + losses.sum(1).squeeze().mean()
    return loss
# (removed scraped-page artifacts that were accidentally pasted here:
#  "评论列表" / "文章目录" — comment-list / table-of-contents page furniture)