def kldivloss_no_reduce_test():
    """Build the test spec for element-wise (unreduced) KLDivLoss.

    One random 10x10 target is drawn up front and captured by both the
    functional under test and the reference implementation. Both sides
    therefore compare against the same target values.

    Returns:
        dict: test description with the keys ``fullname``, ``constructor``,
        ``input_fn``, ``reference_fn`` and ``pickle``.
    """
    target = Variable(torch.randn(10, 10))

    def kl_div_unreduced(i):
        # Cast the shared target to the input's tensor type so the spec
        # works for whichever tensor type the caller passes in.
        return F.kl_div(i, target.type_as(i), reduce=False)

    def make_input():
        # Log of uniform [0, 1) samples; kl_div's input is given as log-values.
        return torch.rand(10, 10).log()

    def reference(i, _):
        # The reference receives the raw tensor (.data), not the Variable.
        reference_loss = loss_reference_fns['KLDivLoss']
        return reference_loss(i, target.data.type_as(i), reduce=False)

    return {
        'fullname': 'KLDivLoss_no_reduce',
        'constructor': wrap_functional(kl_div_unreduced),
        'input_fn': make_input,
        'reference_fn': reference,
        # Presumably disabled because local closures cannot be pickled.
        'pickle': False,
    }
评论列表
文章目录