def test_kthvalue(self):
SIZE = 50
x = torch.rand(SIZE, SIZE, SIZE)
x0 = x.clone()
k = random.randint(1, SIZE)
res1val, res1ind = torch.kthvalue(x, k, False)
res2val, res2ind = torch.sort(x)
self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0)
self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0)
# test use of result tensors
k = random.randint(1, SIZE)
res1val = torch.Tensor()
res1ind = torch.LongTensor()
torch.kthvalue(x, k, False, out=(res1val, res1ind))
res2val, res2ind = torch.sort(x)
self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0)
self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0)
# test non-default dim
k = random.randint(1, SIZE)
res1val, res1ind = torch.kthvalue(x, k, 0, False)
res2val, res2ind = torch.sort(x, 0)
self.assertEqual(res1val, res2val[k - 1], 0)
self.assertEqual(res1ind, res2ind[k - 1], 0)
# non-contiguous
y = x.narrow(1, 0, 1)
y0 = y.contiguous()
k = random.randint(1, SIZE)
res1val, res1ind = torch.kthvalue(y, k)
res2val, res2ind = torch.kthvalue(y0, k)
self.assertEqual(res1val, res2val, 0)
self.assertEqual(res1ind, res2ind, 0)
# check that the input wasn't modified
self.assertEqual(x, x0, 0)
# simple test case (with repetitions)
y = torch.Tensor((3, 5, 4, 1, 1, 5))
self.assertEqual(torch.kthvalue(y, 3)[0], torch.Tensor((3,)), 0)
self.assertEqual(torch.kthvalue(y, 2)[0], torch.Tensor((1,)), 0)
评论列表
文章目录