def test_prod_without_zeros_custom_dtype(self):
    """
    Check that ProdWithoutZeros() honours a caller-supplied output dtype.

    Every (input dtype, output dtype) pair drawn from
    ``theano.scalar.all_types`` is used to build a ProdWithoutZeros
    variable, whose dtype must equal the requested output dtype.  When
    neither dtype is complex, the graph is also compiled and run once on
    random data.
    """
    # The axis is not expected to affect the resulting dtype, but we still
    # rotate through several axis settings, one per dtype pair.
    axis_choices = [None, 0, 1, [], [0], [1], [0, 1]]
    n_choices = len(axis_choices)
    pair_count = 0
    for input_dtype in imap(str, theano.scalar.all_types):
        x = tensor.matrix(dtype=input_dtype)
        for output_dtype in imap(str, theano.scalar.all_types):
            op = ProdWithoutZeros(
                axis=axis_choices[pair_count % n_choices],
                dtype=output_dtype)
            prod_woz_var = op(x)
            pair_count += 1
            assert prod_woz_var.dtype == output_dtype

            # Only compile and execute when no complex dtype is involved.
            if ('complex' not in output_dtype and
                    'complex' not in input_dtype):
                f = theano.function([x], prod_woz_var)
                # Uniform samples scaled to [0, 3), cast to the input dtype.
                data = (numpy.random.rand(2, 3) * 3).astype(input_dtype)
                f(data)
# Comment list
# Table of contents