def test_dot22():
    """Check that matrix-matrix ``T.dot`` is optimized and computes correctly.

    For every pair of input dtypes drawn from float32/float64/complex64/
    complex128, compile ``T.dot(a, b)`` under ``mode_blas_opt`` and check the
    optimized graph:

    * identical dtypes: the graph must contain the ``_dot22`` op;
    * mixed dtypes: the graph must still contain a ``T.Dot`` op.

    The compiled function is then evaluated on regular and zero-sized shapes
    (including an empty inner dimension), and its output is compared with
    ``numpy.dot`` on the same inputs.
    """
    dtypes = ['float32', 'float64', 'complex64', 'complex128']
    for dtype1 in dtypes:
        a = T.matrix(dtype=dtype1)
        for dtype2 in dtypes:
            b = T.matrix(dtype=dtype2)
            f = theano.function([a, b], T.dot(a, b), mode=mode_blas_opt)
            topo = f.maker.fgraph.toposort()
            if dtype1 == dtype2:
                assert _dot22 in [x.op for x in topo], (dtype1, dtype2)
            else:
                assert any(isinstance(x.op, T.Dot) for x in topo), (
                    dtype1, dtype2)
            rng = numpy.random.RandomState(unittest_tools.fetch_seed())

            # Defined inside the loop on purpose: it closes over the current
            # `f`, `dtype1`, `dtype2` and `rng`, and is only called right
            # below, before any of them are rebound.
            def check(a_shp, b_shp):
                av = rng.uniform(size=a_shp).astype(dtype1)
                bv = rng.uniform(size=b_shp).astype(dtype2)
                # Compare against numpy rather than merely running `f`.
                # For an empty inner dimension numpy.dot returns zeros,
                # so this also catches an output buffer that is never
                # written. rtol is loose enough for float32/complex64.
                numpy.testing.assert_allclose(
                    f(av, bv), numpy.dot(av, bv), rtol=1e-5,
                    err_msg=str((dtype1, dtype2, a_shp, b_shp)))

            check((3, 4), (4, 5))
            check((0, 4), (4, 5))
            check((3, 0), (0, 5))
            check((3, 4), (4, 0))
            check((0, 4), (4, 0))
            check((0, 0), (0, 0))
评论列表
文章目录