ops_test.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:DeepLearning_VirtualReality_BigData_Project 作者: rashmitripathi 项目源码 文件源码
def test_invalid(self):
    """Check that ops.matmul rejects operand combinations it cannot handle.

    Cases exercised, in order:
      * a rank-1 operand paired with a rank-0 (scalar) operand -> ValueError;
      * a rank-1 operand paired with a rank-3 operand -> NotImplementedError;
      * two rank-1 operands with different axis names -> ValueError;
      * a rank-2 operand multiplied by itself -> NotImplementedError;
      * two rank-1 operands sharing axis 'x' but with sizes 2 and 3
        -> ValueError.
    """
    # Operands are filled with ones; only their shapes and axis names matter
    # for these error checks.
    scalar_lt = core.LabeledTensor(array_ops.ones(()), [])
    x_lt = core.LabeledTensor(array_ops.ones((2,)), ['x'])
    x2_lt = core.LabeledTensor(array_ops.ones((3,)), ['x'])
    y_lt = core.LabeledTensor(array_ops.ones((3,)), ['y'])
    xy_lt = core.LabeledTensor(array_ops.ones((2, 3)), ['x', 'y'])
    xyz_lt = core.LabeledTensor(array_ops.ones((2, 3, 1)), ['x', 'y', 'z'])

    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
    # which was removed from unittest.TestCase in Python 3.12 (the new name
    # requires Python 3.2+).
    with self.assertRaisesRegex(ValueError, 'inputs with at least rank'):
      ops.matmul(x_lt, scalar_lt)

    with self.assertRaises(NotImplementedError):
      ops.matmul(x_lt, xyz_lt)

    # 'x' vs 'y': the operands share no axis name.
    with self.assertRaisesRegex(ValueError, 'exactly one axis in common'):
      ops.matmul(x_lt, y_lt)

    with self.assertRaises(NotImplementedError):
      ops.matmul(xy_lt, xy_lt)

    # Both operands use axis 'x', but with sizes 2 and 3.
    with self.assertRaisesRegex(ValueError, 'does not match'):
      ops.matmul(x_lt, x2_lt)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号