def nllloss_no_reduce_weights_ignore_index_neg_test():
    """Build the spec for NLLLoss with class weights, ``reduce=False`` and
    a negative ``ignore_index`` (-1).

    Returns:
        dict: the test description with keys ``fullname``, ``constructor``
        (an ``F.nll_loss`` call wrapped by ``wrap_functional``), ``input``
        (a ``(15, 10)`` tensor), ``reference_fn`` (delegating to
        ``loss_reference_fns['NLLLoss']``) and ``pickle=False``.
        NOTE(review): presumably consumed by the module/criterion test
        harness that builds tests from such dicts -- confirm against caller.
    """
    # 15 random class indices in [0, 9], one per row of the (15, 10) input.
    targets = torch.Tensor(15).uniform_().mul(10).floor().long()
    # Random indices in [0, 9] never equal -1, so without this the
    # ignore_index behavior this test is named for would never be hit.
    # Mark the first few targets as ignored to exercise that path.
    targets.narrow(0, 0, 3).fill_(-1)
    t = Variable(targets)
    weight = torch.rand(10)  # one weight per class (10 classes)

    def kwargs(i):
        # Cast the weight to match ``i`` so both paths use the same type.
        return {'weight': weight.type_as(i), 'reduce': False,
                'ignore_index': -1}

    return dict(
        fullname='NLLLoss_no_reduce_weights_ignore_index_neg',
        constructor=wrap_functional(
            # NOTE(review): kwargs receives ``i.data`` here (not ``i``), so
            # the weight is cast against the underlying tensor; presumably
            # F.nll_loss in this API version wants a plain-tensor weight.
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i.data))),
        # Offset by 1e-2 before log so no element is log(0) = -inf.
        input=torch.rand(15, 10).add(1e-2).log(),
        # type_as(i).long() presumably moves the targets to i's device/type
        # family while keeping integer class indices (-1 survives the cast).
        reference_fn=lambda i, _:
            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)),
        pickle=False)
评论列表
文章目录