def _testMath(self, torchfn, mathfn):
size = (10, 5)
# contiguous
m1 = torch.randn(*size)
res1 = torchfn(m1[4])
res2 = res1.clone().zero_()
for i, v in enumerate(m1[4]):
res2[i] = mathfn(v)
self.assertEqual(res1, res2)
# non-contiguous
m1 = torch.randn(*size)
res1 = torchfn(m1[:,4])
res2 = res1.clone().zero_()
for i, v in enumerate(m1[:,4]):
res2[i] = mathfn(v)
self.assertEqual(res1, res2)
评论列表
文章目录