def test_gemm_opt_double_gemm():
    """This is the pattern that shows up in the autoencoder.

    Two things are checked:

    1. ``Z * c + a * dot(X, Y) + b * dot(R, S).T`` is handed to
       ``just_gemm`` with ``expected_nb_gemm=2``.
    2. A graph that already contains one explicit ``gemm_inplace`` plus an
       ``a * dot(X, Y)`` term is compiled with ``mode='FAST_RUN'``.  The
       compiled graph must contain no ``T.Dot`` or ``_dot22`` node, and its
       first output must match the same graph compiled with
       ``compile.Mode(linker='py', optimizer=None)``.

    Raises:
        Failure: if a dot node is left in the ``FAST_RUN`` graph, or the
            two compiled functions differ by more than the tolerance
            (1e-6 for float32, 1e-8 otherwise).  The ``FAST_RUN`` graph is
            printed before the exception propagates.
    """
    X, Y, Z, a, b = T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar()
    R, S, c = T.matrix(), T.matrix(), T.scalar()
    inputs = [X, Y, Z, a, b, R, S, c]
    # One shape per entry of `inputs`, in the same order; () is a scalar.
    ishapes = [(4, 3), (3, 5), (4, 5), (), (), (5, 9), (9, 4), ()]

    # just_gemm is defined elsewhere; give it its own list objects so the
    # lists used below are untouched whatever it does with its arguments.
    just_gemm(list(inputs),
              [Z * c + a * T.dot(X, Y) + b * T.dot(R, S).T],
              ishapes=list(ishapes),
              expected_nb_gemm=2)

    # `c` does not appear in this graph, hence on_unused_input='ignore'
    # in both compilations below.
    outputs = [(a * T.dot(X, Y)
                + gemm_inplace(Z, b, S.T, R.T,
                               T.constant(1.0).astype(config.floatX)))]

    def fresh_values():
        # Re-seed on every call so both functions receive equal but
        # distinct arrays; `f` declares its inputs mutable, so the arrays
        # passed to it must not be reused for `g`.
        rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
        return [numpy.asarray(rng.randn(*sh), config.floatX)
                for sh in ishapes]

    # Built outside the ``try`` because the ``except`` handler reads ``f``;
    # inside it, an early Failure would surface as an unbound-name error.
    f = inplace_func([In(var, mutable=True) for var in inputs], outputs,
                     mode='FAST_RUN', on_unused_input='ignore')
    try:
        for node in f.maker.fgraph.apply_nodes:
            if isinstance(node.op, T.Dot):
                raise Failure('dot in graph')
            if node.op == _dot22:
                raise Failure('_dot22 in graph')

        # Unoptimized reference; no structural check is made on its graph.
        g = inplace_func(inputs, outputs,
                         mode=compile.Mode(linker='py', optimizer=None),
                         on_unused_input='ignore')

        r0 = f(*fresh_values())
        r1 = g(*fresh_values())
        max_abs_err = numpy.max(numpy.abs(r0[0] - r1[0]))
        eps = 1.0e-6 if config.floatX == 'float32' else 1.0e-8
        if max_abs_err > eps:
            # The value reported is an absolute error, so label it as one.
            raise Failure(
                'GEMM is computing the wrong output. max_abs_err =',
                max_abs_err)
    except Failure:
        # Dump the optimized graph to help diagnose the failure, then
        # let the original exception propagate.
        for node in f.maker.fgraph.toposort():
            print('GRAPH', node)
        raise
# Comment list
# Table of contents