def test_bernoulli_variable(self):
    """Check that Bernoulli sampling only ever produces zeros and ones.

    Exercises in-place ``bernoulli_`` on a Variable-wrapped byte tensor,
    with the probability passed first as a Python float and then as a
    Variable-wrapped tensor, and finally out-of-place ``bernoulli`` on a
    plain tensor of random values.
    """
    # TODO: remove once we merge Variable and Tensor

    def only_zeros_and_ones(values):
        # An entry that differs from both 0 and 1 leaves a nonzero product,
        # so the sum is zero exactly when every entry is 0 or 1.
        differs_from_zero = torch.ne(values, 0)
        differs_from_one = torch.ne(values, 1)
        return differs_from_zero.mul_(differs_from_one).sum() == 0

    samples = torch.autograd.Variable(torch.ByteTensor(10, 10))

    # Probability supplied as a scalar.
    samples.bernoulli_(0.5)
    self.assertTrue(only_zeros_and_ones(samples))

    # Probability supplied as a tensor; its shape (10,) differs from the
    # (10, 10) target being filled.
    tensor_prob = torch.autograd.Variable(torch.rand(10))
    samples.bernoulli_(tensor_prob)
    self.assertTrue(only_zeros_and_ones(samples))

    # Out-of-place sampling called directly on a tensor of random values.
    uniform = torch.rand(5, 5)
    self.assertTrue(only_zeros_and_ones(uniform.bernoulli()))
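# Comment list (scraped page label, originally "评论列表")
# Article table of contents (scraped page label, originally "文章目录")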