def test_mean_default_dtype(self):
    """
    Test the default dtype of a mean().

    For each scalar type in ``theano.scalar.all_types`` (converted with
    ``str()`` to a dtype name), build a symbolic matrix of that dtype and
    take its mean over one axis specification from ``axes``.

    Expected result dtype:

    * ``'float64'`` when the input dtype is in ``tensor.discrete_dtypes``
      and ``axis`` is not the empty list;
    * the input dtype itself in every other case, including discrete dtypes
      with ``axis == []``.

    Each graph is then compiled and evaluated once on a small random array
    as a smoke test.
    """
    # Each dtype is paired with a single axis spec, cycling through this
    # list, so many axis combinations get exercised without a full cross
    # product. Apart from the empty-list case handled below, the expected
    # dtype does not depend on which axes are reduced.
    axes = [None, 0, 1, [], [0], [1], [0, 1]]
    for idx, dtype in enumerate(imap(str, theano.scalar.all_types)):
        axis = axes[idx % len(axes)]
        x = tensor.matrix(dtype=dtype)
        m = x.mean(axis=axis)

        if dtype in tensor.discrete_dtypes and axis != []:
            expected = 'float64'
        else:
            expected = dtype
        # Report both dtype and axis on failure: the axis varies from one
        # dtype to the next, so both are needed to reproduce a failing case.
        assert m.dtype == expected, (m, m.dtype, dtype, axis)

        # Smoke test: compile and evaluate on 3x4 values drawn from [0, 10),
        # cast to the dtype under test.
        f = theano.function([x], m)
        data = (numpy.random.rand(3, 4) * 10).astype(dtype)
        f(data)
评论列表
文章目录