def test_inplace0():
    """Check which gemm variant appears in graphs compiled with FAST_RUN.

    Two functions are compiled through ``inplace_func``:

    * ``Z * (Z + b * dot(R, S).T)`` -- ``gemm_inplace`` must not be
      inserted (it would create cycles); ``gemm_no_inplace`` is expected
      in the graph instead.
    * ``Z * (c * Z + a * dot(X, Y) + b * dot(R, S).T)`` -- ``gemm_inplace``
      is expected, working in-place on ``Z * c``.

    Raises:
        Failure: if ``gemm_inplace`` is present in the first graph, or
            absent from the second one.
        AssertionError: if ``gemm_no_inplace`` is absent from the first
            graph.
    """
    X, Y, Z, a, b = (T.matrix('X'), T.matrix('Y'), T.matrix('Z'),
                     T.scalar('a'), T.scalar('b'))
    R, S, c = T.matrix('R'), T.matrix('S'), T.scalar('c')

    # should fail to insert gemm_inplace because gemm_inplace would
    # create cycles
    f = inplace_func([Z, b, R, S],
                     [Z * (Z + b * T.dot(R, S).T)], mode='FAST_RUN')
    # Collect the ops once; both checks below inspect the same graph.
    ops = [node.op for node in f.maker.fgraph.apply_nodes]
    if gemm_inplace in ops:
        # Print the optimized output expression to help diagnose the failure.
        print(pp(f.maker.fgraph.outputs[0]))
        raise Failure('gemm_inplace in graph')
    assert gemm_no_inplace in ops, 'gemm_no_inplace not in graph'

    # gemm_inplace should be inserted here, to work in-place on Z*c
    f = inplace_func([X, Y, Z, a, b, R, S, c],
                     [Z * (c * Z + a * T.dot(X, Y) + b * T.dot(R, S).T)],
                     mode='FAST_RUN')
    ops = [node.op for node in f.maker.fgraph.apply_nodes]
    if gemm_inplace not in ops:
        # Dump the compiled function's graph to help diagnose the failure.
        theano.printing.debugprint(f)
        raise Failure('no gemm_inplace in graph')
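# 评论列表 (comment list) -- web-page navigation residue, not part of the code
# 文章目录 (table of contents) -- web-page navigation residue, not part of the code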