def get_attr_loss(output, attributes, flip, params):
    """
    Compute attributes loss.

    Sums one cross-entropy term per categorical attribute listed in
    ``params.attr``. Each attribute occupies ``n_cat`` consecutive columns in
    both ``output`` and ``attributes``, in the order given by ``params.attr``.
    The target class for an attribute is the argmax of its column block in
    ``attributes`` (presumably a one-hot encoding -- confirm with callers).

    Args:
        output: 2-D tensor of unnormalized scores (logits), one column block
            per attribute.
        attributes: 2-D tensor with the same column layout as ``output``.
        flip: if True, every target is replaced by a uniformly random
            category that differs from the true one. Flipping requires
            ``n_cat >= 2`` for every attribute (``torch.randint`` raises
            otherwise).
        params: object whose ``attr`` is a sequence of ``(name, n_cat)``
            pairs.

    Returns:
        The sum of the per-attribute cross-entropy losses. This is the
        integer ``0`` when ``params.attr`` is empty.

    Raises:
        TypeError: if ``flip`` is not a ``bool``.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if not isinstance(flip, bool):
        raise TypeError("flip must be a bool, got %s" % type(flip).__name__)
    k = 0
    loss = 0
    for _, n_cat in params.attr:
        # categorical attribute: logits and integer targets for this block
        x = output[:, k:k + n_cat].contiguous()
        y = attributes[:, k:k + n_cat].max(1)[1].view(-1)
        if flip:
            # A shift drawn from [1, n_cat - 1] guarantees that
            # (y + shift) % n_cat is a category different from y.
            # Sample on y's own device instead of forcing .cuda(), so this
            # also works on CPU and on non-default GPUs.
            shift = torch.randint(1, n_cat, y.size(), device=y.device)
            y = (y + shift) % n_cat
        loss += F.cross_entropy(x, y)
        k += n_cat
    return loss
# Comment list
# Table of contents