def ours_train(m, x, labels, data, att_crit=None, optimizers=None):
    """
    Train the direct attribute prediction model (returns the training loss).

    :param m: Model we're using; must expose an ``l2_penalty`` scalar term
    :param x: [batch_size, 3, 224, 224] Image input
    :param labels: [batch_size] variable with indices of the right verbs
    :param data: dataset object forwarded to ``ours_logits`` — presumably
        supplies the verb embeddings / GT attribute matrix used for scoring;
        confirm against ``ours_logits``
    :param att_crit: AttributeLoss module that computes the loss
    :param optimizers: the decorator will use these to update parameters
    :return: scalar loss (cross-entropy term(s) + L2 penalty)
    """
    logits = ours_logits(m, x, data, att_crit=att_crit)

    # Start from the model's L2 regularization term.
    loss = m.l2_penalty
    if len(logits) == 1:
        # Single prediction head: plain mean cross-entropy.
        # NOTE: size_average=True was the deprecated spelling of the default
        # reduction='mean', so it is simply omitted here (same behavior).
        loss += F.cross_entropy(logits[0], labels)
    else:
        # Multiple heads: average the per-head losses together with the loss
        # of the summed (ensembled) logits — len(logits)+1 terms in total,
        # each weighted equally by 1/(len(logits)+1).
        sum_logits = sum(logits)
        for head_logits in logits:
            loss += F.cross_entropy(head_logits, labels) / (len(logits) + 1)
        loss += F.cross_entropy(sum_logits, labels) / (len(logits) + 1)
    return loss
# (stray scraped-page text removed from executable scope: "评论列表" / "文章目录")