def test_local_sampling_dot_csr():
    """Check how the ``local_sampling_dot_csr`` rewrite treats CSR inputs.

    A function is compiled from ``sparse.sampling_dot`` applied to two dense
    matrices and one sparse matrix. The ``specialize`` and
    ``local_sampling_dot_csr`` optimisations are enabled. The optimised
    graph is then inspected:

    * when ``theano.config.blas.ldflags`` is set, no ``sparse.SamplingDot``
      node may remain in the graph;
    * otherwise no ``sparse.opt.SamplingDotCSR`` node may appear, because
      that op's C implementation needs BLAS.

    Raises ``SkipTest`` when ``theano.config.cxx`` is empty (no C++
    compiler configured).
    """
    if not theano.config.cxx:
        raise SkipTest("G++ not available, so we need to skip this test.")

    opt_mode = theano.compile.mode.get_default_mode().including(
        "specialize", "local_sampling_dot_csr")

    # The rewrite is only implemented for the CSR format.
    for fmt in ('csr',):
        # Inputs are created in the order (dense, dense, sparse), which is
        # the argument order sampling_dot receives below.
        lhs = tensor.matrix()
        rhs = tensor.matrix()
        pattern = getattr(theano.sparse, fmt + '_matrix')()
        inputs = [lhs, rhs, pattern]

        fn = theano.function(inputs,
                             sparse.sampling_dot(*inputs),
                             mode=opt_mode)

        if theano.config.blas.ldflags:
            # With BLAS available, the generic op should have been replaced.
            unwanted = sparse.SamplingDot
        else:
            # SamplingDotCSR's C implementation needs BLAS, so it should
            # not be inserted.
            unwanted = sparse.opt.SamplingDotCSR

        graph_ops = [node.op for node in fn.maker.fgraph.toposort()]
        assert not any(isinstance(op, unwanted) for op in graph_ops)