def test_dnn_tag():
    """
    Check that requesting the "cudnn" tag agrees with cuDNN availability.

    A pooling graph is compiled with ``mode_with_gpu.including("cudnn")``
    while ``theano.config.on_opt_error`` is temporarily set to ``"raise"``.

    * If compilation raises ``AssertionError`` or ``RuntimeError``, cuDNN
      must be reported as unavailable.
    * If compilation succeeds, cuDNN must be reported as available and the
      compiled graph must contain at least one ``GpuDnnPool`` node.

    The config value and the logging handlers touched here are restored in
    the ``finally`` clause whether or not compilation succeeds.
    """
    x = T.ftensor4()
    old = theano.config.on_opt_error
    theano.config.on_opt_error = "raise"

    # Route this test's log records into an in-memory buffer; the buffer is
    # never inspected, it only absorbs the output.
    sio = StringIO()
    handler = logging.StreamHandler(sio)
    logging.getLogger('theano.compile.tests.test_dnn').addHandler(handler)
    # Silence original handler when intentionally generating warning messages
    logging.getLogger('theano').removeHandler(theano.logging_default_handler)

    raised = False
    try:
        f = theano.function(
            [x],
            pool_2d(x, ds=(2, 2), ignore_border=True),
            mode=mode_with_gpu.including("cudnn"))
    except (AssertionError, RuntimeError):
        # Compilation is only allowed to fail when cuDNN is missing.
        assert not cuda.dnn.dnn_available()
        raised = True
    finally:
        # Undo every global change made above, in all outcomes.
        theano.config.on_opt_error = old
        logging.getLogger(
            'theano.compile.tests.test_dnn').removeHandler(handler)
        logging.getLogger('theano').addHandler(theano.logging_default_handler)

    # `f` is only bound when compilation succeeded, hence the guard.
    if not raised:
        assert cuda.dnn.dnn_available()
        assert any(isinstance(n.op, cuda.dnn.GpuDnnPool)
                   for n in f.maker.fgraph.toposort())
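# Web-page residue captured with this snippet ("comment list" /
# "article table of contents"); not part of the test module.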