import numpy
import torch
from torch.autograd import Variable

from allennlp.nn import util


# Imports above are needed to run this method, which is excerpted from a
# test class (hence the `self` parameter).
def test_sequence_cross_entropy_with_logits_averages_batch_correctly(self):
    # Test that the batch-averaged loss equals the sum of the per-sequence
    # losses divided by the number of sequences containing any non-padded tokens.
    # Logits for a batch of 5 sequences of length 7 over 4 classes.
    tensor = torch.rand([5, 7, 4])
    # Zero out trailing timesteps to simulate padding; row 3 is entirely padding.
    tensor[0, 3:, :] = 0
    tensor[1, 4:, :] = 0
    tensor[2, 2:, :] = 0
    tensor[3, :, :] = 0
    # Mask of shape (5, 7): 1 for non-padded positions, 0 for padding.
    weights = (tensor != 0.0)[:, :, 0].long()
    # Random targets, zeroed at padded positions so they match the mask.
    targets = torch.LongTensor(numpy.random.randint(0, 3, [5, 7]))
    targets *= weights
    # Wrap in Variables (pre-0.4 PyTorch API, matching this version of the code).
    tensor = Variable(tensor)
    targets = Variable(targets)
    weights = Variable(weights)
    loss = util.sequence_cross_entropy_with_logits(tensor, targets, weights)
    vector_loss = util.sequence_cross_entropy_with_logits(tensor, targets, weights, batch_average=False)
    # The batch has one completely padded row, so the average divides by 4.
    assert numpy.isclose(loss.data.numpy(), vector_loss.data.sum() / 4.0)
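
For reference, the quantity the assertion checks can be reproduced with a short standalone computation. The sketch below is a simplified, hypothetical re-implementation of a masked, batch-averaged sequence loss, written against the modern PyTorch tensor API; the helper name masked_batch_average_loss is invented for illustration, and the code mirrors the behaviour the test expects rather than AllenNLP's exact internals.

import torch
import torch.nn.functional as F

def masked_batch_average_loss(logits, targets, weights):
    # logits: (batch, seq_len, num_classes); weights: (batch, seq_len) 0/1 mask.
    batch, seq_len, num_classes = logits.size()
    # Unreduced token-level cross entropy, reshaped back to (batch, seq_len).
    per_token = F.cross_entropy(
        logits.reshape(-1, num_classes), targets.reshape(-1), reduction="none"
    ).reshape(batch, seq_len)
    # Zero out the loss at padded positions.
    per_token = per_token * weights.float()
    # Per-sequence loss, normalised by each sequence's non-padded length.
    per_sequence = per_token.sum(dim=1) / (weights.sum(dim=1).float() + 1e-13)
    # Average over sequences containing at least one non-padded token; a fully
    # padded row contributes zero to the numerator and is not counted.
    num_non_empty = (weights.sum(dim=1) > 0).float().sum()
    return per_sequence.sum() / num_non_empty

With the tensors built in the test, per_sequence has five entries, the fully padded row contributes zero, and dividing the summed per-sequence losses by the four non-empty sequences reproduces loss, which is exactly what the final assertion states.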