def test_gemm_opt0():
    """Run just_gemm over many subgraphs whose dots can be eliminated.

    Each case combines ``T.dot`` of X and Y (or of their transposes) with
    Z, using the scalars a and b as multipliers, and passes the resulting
    single-output graph to ``just_gemm`` together with the same five
    inputs ``[X, Y, Z, a, b]``.
    """
    X, Y, Z, a, b = XYZab()

    def check(expr, **options):
        # Build a new input list and a one-element output list per call.
        just_gemm([X, Y, Z, a, b], [expr], **options)

    # scaled dot plus scaled Z, in several operand orders
    check(T.dot(X, Y) * a + Z * b)
    check(a * T.dot(X, Y) + b * Z)
    check(b * Z + a * T.dot(X, Y))

    # the same shapes of expression, with subtraction
    check(T.dot(X, Y) * a - Z * b)
    check(a * T.dot(X, Y) - b * Z)
    check(b * Z - a * T.dot(X, Y))

    # with transposes (transposes should be pushed through dot in
    # canonicalize)
    check(b * Z.T - a * T.dot(Y.T, X.T))
    check(b * Z.T + a * b * T.dot(X, Y).T)
    # this case overrides the input shapes that just_gemm is given
    check(
        b * Z + a * T.dot(X, Y).T,
        ishapes=[(5, 3), (3, 4), (4, 5), (), ()],
    )

    # with N multiplications instead of just one
    check((b * b) * Z * a + (a * a) * T.dot(X, Y) * b)
    check(Z + T.dot(X, Y))
    check(Z * b + T.dot(X, Y))
    check(Z + a * b * a * T.dot(X, Y))
    check((b * b) * Z * a - (a * a) * T.dot(X, Y) * b)
    check(Z - T.dot(X, Y))
    check(Z * b - T.dot(X, Y))
    check(Z - a * b * a * T.dot(X, Y))
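# 评论列表 (comment list) -- web-page navigation residue, not part of the test code
# 文章目录 (article table of contents) -- web-page navigation residue, not part of the test code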