def _test_cop(self, torchfn, mathfn):
    """Check a pointwise binary torch op against an element-by-element reference.

    Two cases are exercised: operands sliced so that they are contiguous,
    and operands sliced along a non-leading dimension so that they are
    strided. In each case ``torchfn`` is applied to the whole operands and
    the result is compared, via ``self.assertEqual``, with a tensor filled
    one element at a time using ``mathfn``.

    Args:
        torchfn: Callable taking two tensors of the same 2-D shape and
            returning the element-wise result.
        mathfn: Callable taking one element of each operand and returning
            the expected value for that position.
    """
    def reference_implementation(res2):
        # sm1 and sm2 are read from the enclosing scope when this is called,
        # so the loop always sees the operands of the case being checked.
        for i, j in iter_indices(sm1):
            # Row-major flat index into the 1-D sm2: row * row_length + col.
            # (The row length is size(1); size(0) only works for square sm1.)
            idx1d = i * sm1.size(1) + j
            res2[i, j] = mathfn(sm1[i, j], sm2[idx1d])
        return res2

    # contiguous
    m1 = torch.randn(10, 10, 10)
    m2 = torch.randn(10, 10 * 10)
    sm1 = m1[4]
    sm2 = m2[4]
    # sm2 is 1-D with the same number of elements as the 2-D sm1. Pointwise
    # ops no longer pair such operands by element count (they broadcast
    # instead), so hand the op a view of sm2 with sm1's shape. sm2 itself
    # stays 1-D because reference_implementation indexes it with idx1d.
    res1 = torchfn(sm1, sm2.view(sm1.size()))
    res2 = reference_implementation(res1.clone())
    self.assertEqual(res1, res2)

    # non-contiguous
    m1 = torch.randn(10, 10, 10)
    m2 = torch.randn(10 * 10, 10 * 10)
    sm1 = m1[:, 4]
    sm2 = m2[:, 4]
    # view() shares sm2's storage rather than copying, so the operand passed
    # to the op keeps the strided layout this case is meant to exercise.
    res1 = torchfn(sm1, sm2.view(sm1.size()))
    res2 = reference_implementation(res1.clone())
    self.assertEqual(res1, res2)
评论列表
文章目录