test_blas.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_gemm_opt0():
    """Run ``just_gemm`` over many graphs mixing ``T.dot(X, Y)`` with ``Z``.

    Every case combines a (possibly scaled, possibly transposed) dot
    product of ``X`` and ``Y`` with a (possibly scaled) ``Z`` by addition
    or subtraction.  The original summary of this test reads: "Many
    subgraphs whose dots can be eliminated".

    NOTE(review): ``just_gemm`` is defined elsewhere; presumably it checks
    that the optimized graph reduces to a single GEMM -- confirm against
    its definition.
    """
    X, Y, Z, a, b = XYZab()

    def check(expression, **options):
        # Build a fresh input list on every call and forward any keyword
        # options (e.g. ``ishapes``) untouched.
        just_gemm([X, Y, Z, a, b], [expression], **options)

    # One scalar factor on each term: sums first, then differences.
    check(T.dot(X, Y) * a + Z * b)
    check(a * T.dot(X, Y) + b * Z)
    check(b * Z + a * T.dot(X, Y))
    check(T.dot(X, Y) * a - Z * b)
    check(a * T.dot(X, Y) - b * Z)
    check(b * Z - a * T.dot(X, Y))

    # with transposes (transposes should be pushed through dot in canonicalize)
    check(b * Z.T - a * T.dot(Y.T, X.T))
    check(b * Z.T + a * b * T.dot(X, Y).T)
    # This case overrides the input shapes so that the transposed dot
    # result, (5, 3) x (3, 4) -> (5, 4) -> transposed (4, 5), matches Z.
    check(
        b * Z + a * T.dot(X, Y).T,
        ishapes=[(5, 3), (3, 4), (4, 5), (), ()],
    )

    # with N multiplications instead of just one
    check((b * b) * Z * a + (a * a) * T.dot(X, Y) * b)
    check(Z + T.dot(X, Y))
    check(Z * b + T.dot(X, Y))
    check(Z + a * b * a * T.dot(X, Y))
    check((b * b) * Z * a - (a * a) * T.dot(X, Y) * b)
    check(Z - T.dot(X, Y))
    check(Z * b - T.dot(X, Y))
    check(Z - a * b * a * T.dot(X, Y))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号