def test_topk(self):
    """Check ``Tensor.topk`` against a reference derived from a full sort.

    One random 3-D tensor is drawn up front.  Then, over several rounds,
    a random dimension and a random ``k`` are picked and the comparison is
    run for every combination of (transposed view / plain tensor) and
    (both ordering directions).
    """

    def reference_topk(tensor, k, dim, descending):
        # Sort along ``dim`` and keep only the first ``k`` entries of both
        # the sorted values and their source indices.
        values, positions = tensor.sort(dim, descending)
        return values.narrow(dim, 0, k), positions.narrow(dim, 0, k)

    def check_against_reference(tensor, k, dim, descending):
        got_vals, got_idx = tensor.topk(k, dim, descending, True)
        want_vals, want_idx = reference_topk(tensor, k, dim, descending)

        # Values should be exactly equivalent.
        self.assertEqual(want_vals, got_vals, 0)

        # Identical indices need no further checking.
        if want_idx.eq(got_idx).all():
            return

        # Indices may legitimately differ between implementations, since
        # the relative order of selection is not guaranteed.  In that case
        # confirm the topk indices point at equivalent elements: gather
        # from the input with them and compare to the reference values.
        picked = tensor.gather(dim, got_idx)
        self.assertEqual(want_vals, picked, 0)

    shape = [random.randint(1, SIZE) for _ in range(3)]
    t = torch.rand(*shape)

    # 3 x 3 rounds, each covering every (transpose, direction) pairing.
    for _round in range(3 * 3):
        for use_transpose in (True, False):
            for descending in (True, False):
                subject = t

                if use_transpose:
                    # Draw two distinct dimensions to swap; ``t`` has three
                    # dimensions, so a distinct second one always exists.
                    first = random.randrange(t.ndimension())
                    second = random.randrange(t.ndimension())
                    while second == first:
                        second = random.randrange(t.ndimension())

                    subject = t.transpose(first, second)

                dim = random.randrange(subject.ndimension())
                k = random.randint(1, subject.size(dim))
                check_against_reference(subject, k, dim, descending)
评论列表
文章目录