ops_test.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:DeepLearning_VirtualReality_BigData_Project 作者: rashmitripathi 项目源码 文件源码
def test_matrix_matrix(self):
    """Check ops.matmul on two rank-2 LabeledTensors that share the 'y' axis.

    The expected value is the raw math_ops.matmul of the underlying tensors,
    labelled ['x', 'z']. Reversing the axis order of either operand (or both)
    must yield that same labelled product. Swapping the operands must yield
    its axis-reversed transpose.
    """

    def reverse_axes(labeled):
        # Transpose by listing the tensor's axis names back to front.
        return core.transpose(labeled, list(labeled.axes.keys())[::-1])

    # left is 2x3 over ('x', 'y'); right is 3x4 over ('y', 'z').
    left = core.LabeledTensor(
        array_ops.reshape(math_ops.range(6), (2, 3)), ['x', 'y'])
    right = core.LabeledTensor(
        array_ops.reshape(math_ops.range(12), (3, 4)), ['y', 'z'])

    product = ops.matmul(left, right)
    expected = core.LabeledTensor(
        math_ops.matmul(left.tensor, right.tensor), ['x', 'z'])
    self.assertLabeledTensorsEqual(product, expected)

    # Each flag pair says which operand to transpose before multiplying;
    # the labelled result must not depend on the operands' axis order.
    for flip_left, flip_right in ((False, True), (True, False), (True, True)):
        lhs = reverse_axes(left) if flip_left else left
        rhs = reverse_axes(right) if flip_right else right
        self.assertLabeledTensorsEqual(ops.matmul(lhs, rhs), expected)

    # Reversing operand order should give the transposed product ('z', 'x').
    swapped = ops.matmul(right, left)
    self.assertLabeledTensorsEqual(swapped, reverse_axes(expected))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号