def test_upcasting_scalar_nogemm():
    """Check that no Gemm node is produced when the scale forces an upcast.

    The expression ``dot(w, v) * alpha + t`` is compiled twice on ``fmatrix``
    inputs: once with a ``dscalar`` scale and once with a ``cscalar`` scale.
    In both cases the optimization must not crash and the compiled graph
    must not contain any node whose op is a ``Gemm`` instance.  The second
    compilation runs with ``config.on_opt_error`` set to ``'raise'``; the
    previous setting is restored afterwards.
    """

    def compile_scaled_dot(scale):
        # Build dot(w, v) * scale + t on fresh matrix inputs and compile it.
        v = T.fmatrix('v')
        w = T.fmatrix('w')
        t = T.fmatrix('t')
        expr = T.dot(w, v) * scale + t
        return theano.function([w, v, t, scale], expr)

    def gemm_free(fn):
        # True when no node of the compiled graph applies a Gemm op.
        # For debugging: theano.printing.debugprint(fn, print_type=True)
        nodes = fn.maker.fgraph.toposort()
        return not any(isinstance(node.op, Gemm) for node in nodes)

    # Case 1: dscalar scale, compiled under the current on_opt_error setting.
    compiled = compile_scaled_dot(T.dscalar('a'))
    assert gemm_free(compiled)

    # Case 2: cscalar scale, compiled with on_opt_error forced to 'raise'.
    complex_scale = T.cscalar('a')
    saved_on_opt_error = config.on_opt_error
    try:
        config.on_opt_error = 'raise'
        compiled = compile_scaled_dot(complex_scale)
    finally:
        # Always put the global setting back, even if compilation failed.
        config.on_opt_error = saved_on_opt_error
    assert gemm_free(compiled)
评论列表
文章目录