util_test.py source code

python

Project: allennlp   Author: allenai
import numpy
import torch
from torch.autograd import Variable

from allennlp.nn import util


def test_sequence_cross_entropy_with_logits_averages_batch_correctly(self):
    # Test that the batch-averaged loss equals the sum of the per-sequence
    # losses divided by the number of sequences that contain at least one
    # non-padded token.
    tensor = torch.rand([5, 7, 4])
    # Zero out trailing timesteps to simulate padding; row 3 is fully padded.
    tensor[0, 3:, :] = 0
    tensor[1, 4:, :] = 0
    tensor[2, 2:, :] = 0
    tensor[3, :, :] = 0
    # Mask of shape (batch_size, sequence_length): 1 for real tokens, 0 for padding.
    weights = (tensor != 0.0)[:, :, 0].long()
    targets = torch.LongTensor(numpy.random.randint(0, 3, [5, 7]))
    targets *= weights

    tensor = Variable(tensor)
    targets = Variable(targets)
    weights = Variable(weights)
    loss = util.sequence_cross_entropy_with_logits(tensor, targets, weights)

    vector_loss = util.sequence_cross_entropy_with_logits(tensor, targets, weights, batch_average=False)
    # The batch has one completely padded row, so divide by 4 rather than 5.
    # Compare approximately: summation order can change the last float bits.
    numpy.testing.assert_almost_equal(loss.data.numpy(), vector_loss.data.numpy().sum() / 4)
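To make the identity this test checks concrete without installing allennlp or PyTorch, the sketch below re-implements a masked sequence cross-entropy in NumPy. It is a hypothetical reconstruction, not allennlp's code: it assumes each sequence's loss is the mean negative log-likelihood over its non-padded tokens, and that the batch average divides by the number of non-empty sequences, which is what the test's assertion implies. The name `masked_sequence_cross_entropy` is illustrative only.

```python
import numpy as np


def masked_sequence_cross_entropy(logits, targets, weights, batch_average=True):
    """Hypothetical NumPy sketch of a masked sequence cross-entropy.

    logits:  float array, shape (batch_size, seq_len, num_classes)
    targets: int array,   shape (batch_size, seq_len)
    weights: 0/1 mask,    shape (batch_size, seq_len)
    """
    # Numerically stable log-softmax over the class dimension.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Negative log-probability of each target class, shape (batch_size, seq_len).
    nll = -np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    mask = weights.astype(np.float64)
    # Assumption: each sequence's loss is the mean over its non-padded tokens.
    # The epsilon turns 0/0 into 0 for fully padded rows.
    tokens_per_sequence = mask.sum(axis=1)
    per_sequence = (nll * mask).sum(axis=1) / (tokens_per_sequence + 1e-13)
    if not batch_average:
        return per_sequence
    # Average only over sequences that contain at least one real token.
    num_non_empty = (tokens_per_sequence > 0).sum()
    return per_sequence.sum() / (num_non_empty + 1e-13)


# Usage mirroring the allennlp test: 5 sequences, row 3 fully padded.
rng = np.random.default_rng(0)
logits = rng.random((5, 7, 4))
logits[0, 3:, :] = 0
logits[1, 4:, :] = 0
logits[2, 2:, :] = 0
logits[3, :, :] = 0
weights = (logits != 0.0)[:, :, 0].astype(np.int64)
targets = rng.integers(0, 3, size=(5, 7)) * weights

loss = masked_sequence_cross_entropy(logits, targets, weights)
per_seq = masked_sequence_cross_entropy(logits, targets, weights, batch_average=False)
np.testing.assert_almost_equal(loss, per_seq.sum() / 4)
```

The small epsilon in both denominators keeps the fully padded row from producing NaN, so that row contributes 0 to the sum and the correct divisor is the 4 non-empty sequences rather than the batch size of 5.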