test_nn.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:pytorch 作者: pytorch 项目源码 文件源码
def nllloss_no_reduce_weights_ignore_index_test():
    t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
    weight = torch.rand(10)

    def kwargs(i):
        return {'weight': weight.type_as(i), 'reduce': False,
                'ignore_index': 2}

    return dict(
        fullname='NLLLoss_no_reduce_weights_ignore_index',
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i.data))),
        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
        reference_fn=lambda i, _:
            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)),
        pickle=False)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号