def test_dtype_mismatch(self):
    """``ifelse`` must reject branches whose dtypes do not match.

    A shared variable of dtype ``self.dtype`` is paired with an ``int8``
    cast derived from it. Building ``ifelse`` with these two branches is
    expected to raise ``TypeError`` in either branch order.
    """
    seeded_rng = numpy.random.RandomState(utt.fetch_seed())
    values = seeded_rng.rand(5).astype(self.dtype)
    shared_branch = self.shared(values)
    # Scaling before the cast keeps the integer branch from being all
    # zeros; only its dtype matters for this check.
    int8_branch = tensor.cast(shared_branch * 10, 'int8')
    condition = theano.tensor.iscalar('cond')

    # Check both orderings so the mismatch is caught whichever branch
    # carries the "odd" dtype.
    branch_pairs = (
        (shared_branch, int8_branch),
        (int8_branch, shared_branch),
    )
    for then_branch, else_branch in branch_pairs:
        self.assertRaises(TypeError, ifelse, condition,
                          then_branch, else_branch)
评论列表
文章目录