def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()],
              max_graphlen=0, expected_nb_gemm=1):
    """Check that the graph for `o` is optimized to use `gemm_inplace`.

    The outputs `o` are compiled twice from the inputs `i`: once with
    ``mode='FAST_RUN'`` and once with the python linker and no optimizer.
    The optimized graph is inspected for leftover dot operations and for the
    number of `gemm_inplace` nodes, the unoptimized graph is checked to
    contain no `gemm_inplace`, and both functions are run on the same random
    inputs to compare their first output numerically.

    Parameters
    ----------
    i
        Inputs of the compiled functions (presumably a list of symbolic
        variables -- confirm against callers).
    o
        Outputs of the compiled functions; only the first returned value is
        compared numerically.
    ishapes
        One shape per input; each is unpacked into ``rng.randn`` to build the
        numeric test value. The default list is only iterated, never mutated.
    max_graphlen
        Maximum allowed number of nodes in the optimized graph. ``0``
        (the default) disables the check.
    expected_nb_gemm
        Exact number of `gemm_inplace` nodes expected in the optimized graph.

    Raises
    ------
    Failure
        If a `T.Dot` or `_dot22` node remains in the optimized graph, or if
        the optimized and unoptimized outputs differ by more than the
        tolerance. The optimized graph is printed before re-raising.
    AssertionError
        If the number of `gemm_inplace` nodes differs from
        `expected_nb_gemm`, or the optimized graph has more than
        `max_graphlen` nodes.
    Exception
        If the unoptimized graph already contains `gemm_inplace`.
    """
    try:
        # Optimized function; inputs are declared mutable so that the
        # in-place gemm variant is allowed to be used.
        f = inplace_func(
            [In(ii, mutable=True, allow_downcast=True) for ii in i],
            o,
            mode='FAST_RUN',
            on_unused_input='ignore')

        nb_gemm = 0
        for node in f.maker.fgraph.apply_nodes:
            if isinstance(node.op, T.Dot):
                raise Failure('dot not changed to gemm_inplace in graph')
            if node.op == _dot22:
                raise Failure('_dot22 not changed to gemm_inplace in graph')
            if node.op == gemm_inplace:
                nb_gemm += 1
        assert nb_gemm == expected_nb_gemm, (nb_gemm, expected_nb_gemm)

        # Reference function: python linker, no optimization.
        g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None),
                         allow_input_downcast=True, on_unused_input='ignore')
        for node in g.maker.fgraph.apply_nodes:
            if node.op == gemm_inplace:
                raise Exception('gemm_inplace in original graph')

        graphlen = len(f.maker.fgraph.toposort())
        # Fail when the optimized graph is LONGER than allowed, as the
        # message states (the comparison used to be inverted).
        # theano.printing.debugprint(f) is useful when investigating this.
        if max_graphlen and graphlen > max_graphlen:
            raise AssertionError(
                'graphlen=%i>%i' % (graphlen, max_graphlen))

        def make_inputs():
            # A fresh generator per call: each function gets its own arrays
            # (f's inputs are mutable), built from the same seed request so
            # both presumably receive identical values.
            rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
            return [numpy.asarray(rng.randn(*sh), config.floatX)
                    for sh in ishapes]

        r0 = f(*make_inputs())
        r1 = g(*make_inputs())

        max_abs_err = numpy.max(numpy.abs(r0[0] - r1[0]))
        # Looser tolerance for single precision.
        eps = 1.0e-6 if config.floatX == 'float32' else 1.0e-8
        if max_abs_err > eps:
            raise Failure('GEMM is computing the wrong output. max_abs_err =',
                          max_abs_err)
    except Failure:
        # Dump the optimized graph to help diagnose the failure. Every
        # Failure raised above happens after `f` has been assigned.
        for node in f.maker.fgraph.toposort():
            print('GRAPH', node)
        raise
# [page residue, not code] Comment list
# [page residue, not code] Table of contents