def bceloss_no_reduce_test():
    """Return the test-case spec for ``F.binary_cross_entropy`` with ``reduce=False``.

    A fixed random 0/1 target of shape (15, 10) is drawn once, when this
    function is called. The returned dict has these entries:

    * ``constructor`` wraps the functional loss via ``wrap_functional``. The
      target is cast to the input's tensor type before each call.
    * ``input_fn`` draws inputs uniformly from [0, 1) and clamps them to
      [0.02, 0.98]. This keeps ``log(i)`` and ``log(1 - i)`` finite.
    * ``reference_fn`` computes the unreduced, element-wise BCE by hand.

    ``check_gradgrad`` and ``pickle`` are both disabled.
    NOTE(review): the reason for disabling them is not visible here.
    """
    # Binary target: 1.0 where the normal sample is positive, else 0.0.
    target = torch.randn(15, 10).gt(0).double()

    def apply_loss(inp):
        # Match the target's type to the input, since the input type may vary
        # between test runs.
        matched_target = Variable(target.type_as(inp.data))
        return F.binary_cross_entropy(inp, matched_target, reduce=False)

    def make_input():
        # Stay away from 0 and 1 so both log terms below are finite.
        return torch.rand(15, 10).clamp_(2e-2, 1 - 2e-2)

    def expected_loss(inp, _module):
        # Element-wise -(y*log(x) + (1-y)*log(1-x)), with no reduction.
        positive_term = target * inp.log()
        negative_term = (1 - target) * (1 - inp).log()
        return -(positive_term + negative_term)

    spec = dict(fullname='BCELoss_no_reduce')
    spec['constructor'] = wrap_functional(apply_loss)
    spec['input_fn'] = make_input
    spec['reference_fn'] = expected_loss
    spec['check_gradgrad'] = False
    spec['pickle'] = False
    return spec
评论列表
文章目录