def test_invalid(self):
    """Check that `ops.matmul` raises for unsupported operand pairs.

    Each case below feeds `ops.matmul` two all-ones labeled tensors and
    asserts the exception type (and, for `ValueError`, a fragment of the
    message) that the call is expected to produce.
    """

    def labeled_ones(shape, axes):
        # All-ones LabeledTensor with the given static shape and axis names.
        return core.LabeledTensor(array_ops.ones(shape), axes)

    # Fixtures are created in the same order as they are listed here.
    scalar = labeled_ones((), [])
    vec_x = labeled_ones((2,), ['x'])
    vec_x_len3 = labeled_ones((3,), ['x'])
    vec_y = labeled_ones((3,), ['y'])
    mat_xy = labeled_ones((2, 3), ['x', 'y'])
    cube_xyz = labeled_ones((2, 3, 1), ['x', 'y', 'z'])

    # A rank-0 operand is rejected.
    with self.assertRaisesRegexp(ValueError, 'inputs with at least rank'):
        ops.matmul(vec_x, scalar)

    # A rank-3 operand is reported as not implemented.
    with self.assertRaises(NotImplementedError):
        ops.matmul(vec_x, cube_xyz)

    # Operands labeled 'x' and 'y' share no axis name.
    with self.assertRaisesRegexp(ValueError, 'exactly one axis in common'):
        ops.matmul(vec_x, vec_y)

    # Two matrices carrying the same two axis names are not implemented.
    with self.assertRaises(NotImplementedError):
        ops.matmul(mat_xy, mat_xy)

    # Both operands are labeled 'x' but have sizes 2 and 3.
    with self.assertRaisesRegexp(ValueError, 'does not match'):
        ops.matmul(vec_x, vec_x_len3)
ops_test.py 文件源码
python
阅读 23
收藏 0
点赞 0
评论 0
评论列表
文章目录