def test_erfinv(self):
def checkType(tensor):
inputValues = torch.randn(4, 4, out=tensor()).clamp(-2., 2.)
self.assertEqual(tensor(inputValues).erf().erfinv(), tensor(inputValues))
# test inf
self.assertTrue(torch.equal(tensor([-1, 1]).erfinv(), tensor([float('-inf'), float('inf')])))
# test nan
self.assertEqual(tensor([-2, 2]).erfinv(), tensor([float('nan'), float('nan')]))
checkType(torch.FloatTensor)
checkType(torch.DoubleTensor)
评论列表
文章目录