def test_matrix_matrix(self):
    """Check ops.matmul on two rank-2 LabeledTensors sharing the 'y' axis.

    The expected result is math_ops.matmul on the raw tensors, labeled
    ['x', 'z']. The test expects the same result whichever axis order each
    input has. Swapping the operands is expected to give the same values
    with the axes reversed to ['z', 'x'].
    """
    xy_lt = core.LabeledTensor(
        array_ops.reshape(math_ops.range(6), (2, 3)), ['x', 'y'])
    yz_lt = core.LabeledTensor(
        array_ops.reshape(math_ops.range(12), (3, 4)), ['y', 'z'])

    matmul_lt = ops.matmul(xy_lt, yz_lt)
    golden_lt = core.LabeledTensor(
        math_ops.matmul(xy_lt.tensor, yz_lt.tensor), ['x', 'z'])
    self.assertLabeledTensorsEqual(matmul_lt, golden_lt)

    def reverse_axes(labeled):
        # Transpose by reversing the axis-name order. For these rank-2
        # inputs, that swaps the two axes.
        return core.transpose(labeled, list(labeled.axes.keys())[::-1])

    # The physical axis order of either operand should not change the result.
    matmul_lt = ops.matmul(xy_lt, reverse_axes(yz_lt))
    self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
    matmul_lt = ops.matmul(reverse_axes(xy_lt), yz_lt)
    self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
    matmul_lt = ops.matmul(reverse_axes(xy_lt), reverse_axes(yz_lt))
    self.assertLabeledTensorsEqual(matmul_lt, golden_lt)

    # Swapping the operands keeps the values but yields axes ['z', 'x'].
    matmul_lt = ops.matmul(yz_lt, xy_lt)
    self.assertLabeledTensorsEqual(matmul_lt, reverse_axes(golden_lt))
ops_test.py 文件源码
python
阅读 21
收藏 0
点赞 0
评论 0
评论列表
文章目录