def _testSelection(self, torchfn, mathfn):
    """Check a tensor selection reduction against a pure-Python reference.

    Compares ``torchfn`` with an elementwise fold of ``mathfn`` on:
    a contiguous 2-D tensor, a non-contiguous 2-D view, a per-row reduction
    that also returns indices, and 1-D inputs containing a NaN.

    Args:
        torchfn: reduction under test. It is called both as ``torchfn(t)``
            (full reduction) and as ``torchfn(t, dim, keepdim)`` returning a
            ``(values, indices)`` pair. Presumably ``torch.max`` or
            ``torch.min`` -- confirm against callers.
        mathfn: binary Python function used as the reference fold, e.g. the
            builtin ``max``/``min`` matching ``torchfn``.
    """
    import math

    # NOTE(review): assumes iter_indices(t) yields every (i, j) index pair of
    # a 2-D tensor -- it is defined outside this view.

    # contiguous
    m1 = torch.randn(100, 100)
    res1 = torchfn(m1)
    res2 = m1[0, 0]
    for i, j in iter_indices(m1):
        res2 = mathfn(res2, m1[i, j])
    self.assertEqual(res1, res2)

    # non-contiguous: selecting index 4 of the middle dimension of a
    # 10x10x10 tensor gives a strided 10x10 view.
    m1 = torch.randn(10, 10, 10)
    m2 = m1[:, 4]
    res1 = torchfn(m2)
    res2 = m2[0, 0]
    for i, j in iter_indices(m2):
        res2 = mathfn(res2, m2[i][j])
    self.assertEqual(res1, res2)

    # with indices: reduce along dim 1 with keepdim=False and rebuild the
    # expected values/indices row by row, starting from column 0.
    m1 = torch.randn(100, 100)
    res1val, res1ind = torchfn(m1, 1, False)
    res2val = m1[:, 0:1].clone().squeeze()
    res2ind = res1ind.clone().fill_(0)
    for i, j in iter_indices(m1):
        # Replace only when mathfn's pick differs from the current best, so
        # ties keep the earliest column index.
        if mathfn(res2val[i], m1[i, j]) != res2val[i]:
            res2val[i] = m1[i, j]
            res2ind[i] = j

    maxerr = 0
    for i in range(res1val.size(0)):
        maxerr = max(maxerr, abs(res1val[i] - res2val[i]))
        self.assertEqual(res1ind[i], res2ind[i])
    # maxerr is a max of absolute differences, so it is already non-negative.
    self.assertLessEqual(maxerr, 1e-5)

    # NaNs: a NaN anywhere must win the reduction, and its position must be
    # reported as the index.
    for index in (0, 4, 99):
        m1 = torch.randn(100)
        m1[index] = float('nan')
        # Call the function under test (not a hard-coded torch.max) so the
        # indexed NaN path is exercised for every torchfn, not only max.
        res1val, res1ind = torchfn(m1, 0, False)
        # float()/int() accept 0-dim and one-element tensors alike, whereas
        # res1val[0] raises on the 0-dim result of a keepdim=False reduction.
        self.assertTrue(math.isnan(float(res1val)))
        self.assertEqual(int(res1ind), index)
        res1val = torchfn(m1)
        self.assertTrue(math.isnan(float(res1val)))
评论列表
文章目录