def test_dot22scalar_cast():
    """
    Test that in `dot22_to_dot22scalar` we properly cast integers to floats.

    For every dtype in `T.int_dtypes`, the expression ``dot(A, A) * y`` is
    compiled with `mode_blas_opt` (``y`` being a scalar of that integer
    dtype) and the ops of the optimized graph are inspected:

    - with a `dmatrix` ``A``, `_dot22scalar` must be present for all dtypes;
    - with an `fmatrix` ``A``, `_dot22` must be present for 'int32' and
      'int64', and `_dot22scalar` for every other integer dtype.
    """
    # Note that this test was failing before d5ff6904.

    def optimized_ops(matrix, int_dtype):
        # Compile dot(matrix, matrix) * <scalar of int_dtype> and return the
        # ops of the resulting function graph in topological order.
        scalar = T.scalar(dtype=int_dtype)
        compiled = theano.function(
            [matrix, scalar],
            T.dot(matrix, matrix) * scalar,
            mode=mode_blas_opt,
        )
        return [node.op for node in compiled.maker.fgraph.toposort()]

    # float64 matrix: the scalar multiplication is expected to be merged
    # into `_dot22scalar` whatever the integer dtype of the scalar.
    double_mat = T.dmatrix()
    for int_dtype in T.int_dtypes:
        assert _dot22scalar in optimized_ops(double_mat, int_dtype)

    # float32 matrix: 'int32' / 'int64' scalars are expected to leave a
    # plain `_dot22` in the graph.
    # NOTE(review): presumably because those dtypes cannot be safely cast
    # to float32 -- confirm against the optimizer.
    single_mat = T.fmatrix()
    for int_dtype in T.int_dtypes:
        wanted_op = (
            _dot22 if int_dtype in ('int32', 'int64') else _dot22scalar
        )
        assert wanted_op in optimized_ops(single_mat, int_dtype)
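# Web-page navigation residue picked up when this snippet was copied
# (not part of the test code):
#   评论列表  (comment list)
#   文章目录  (article table of contents)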