def test_comparison_ops(self):
    """Check elementwise tensor comparisons against scalar comparisons.

    For each of ``==``, ``!=``, ``<``, ``<=``, ``>``, ``>=``, the operator is
    applied to two random 5x5 tensors.  Every element of the resulting mask
    must agree with the same operator applied to the corresponding pair of
    individual elements.
    """
    import operator

    x = torch.randn(5, 5)
    y = torch.randn(5, 5)
    # (symbol for failure messages, operator function) pairs; one table
    # replaces six otherwise identical check loops.
    comparisons = (
        ('==', operator.eq),
        ('!=', operator.ne),
        ('<', operator.lt),
        ('<=', operator.le),
        ('>', operator.gt),
        ('>=', operator.ge),
    )
    for symbol, op in comparisons:
        mask = op(x, y)
        for idx in iter_indices(x):
            # bool() normalizes both sides. Older torch returns Python scalars
            # from integer indexing, newer torch returns 0-dim tensors.
            # assertIs on two distinct tensor objects would always fail, so the
            # raw values cannot be compared by identity.
            expected = bool(op(x[idx], y[idx]))
            actual = bool(mask[idx])
            self.assertIs(
                expected,
                actual,
                msg='x {} y disagrees at index {}'.format(symbol, idx),
            )
# Comment list
# Table of contents