def test_broadcast2(self):
    # Test switch(cst, vector, matrix): with a constant condition the switch
    # should be removed from the compiled graph, and the result must still be
    # the (vector, matrix)-broadcast output of the selected branch.
    #
    # NOTE(review): an earlier comment here read "This case is not optimized
    # for now.", but both checks below assert that no (non-Cast) Elemwise node
    # survives compilation, i.e. that the switch *is* optimized away. Confirm
    # which statement reflects the intended behaviour.
    x = theano.tensor.vector('x', dtype='int32')
    y = theano.tensor.matrix('y', dtype='int64')

    # Inputs shared by both sub-cases: a length-3 vector and a (2, 3) matrix.
    vx = numpy.array([4, 5, 6], dtype='int32')
    vy = numpy.array([[7, 8, 9], [10, 11, 12]], dtype='int64')

    def elemwise_ops(fn, ignore_cast):
        # Return the Elemwise ops remaining in fn's compiled graph. When
        # ignore_cast is true, Elemwise nodes whose scalar op is a Cast are
        # skipped (presumably the int32 -> int64 upcast of x, which may
        # legitimately remain). Short-circuiting means scalar_op is never
        # read when ignore_cast is false.
        return [node.op for node in fn.maker.fgraph.toposort()
                if isinstance(node.op, theano.tensor.Elemwise) and
                not (ignore_cast and
                     isinstance(node.op.scalar_op,
                                theano.scalar.basic.Cast))]

    # Non-zero constant condition: every element is taken from x,
    # broadcast against y's (2, 3) shape.
    f = theano.function([x, y], theano.tensor.switch(1, x, y),
                        mode=self.mode)
    leftover = elemwise_ops(f, ignore_cast=True)
    assert not leftover, leftover
    out = f(vx, vy)
    # Check the shape explicitly: `out == vx` alone would also pass if the
    # result were wrongly left as an un-broadcast length-3 vector.
    assert out.shape == vy.shape, out.shape
    assert numpy.all(out == vx)

    # Zero constant condition: the result is y unchanged.
    f = theano.function([x, y], theano.tensor.switch(0, x, y),
                        mode=self.mode)
    leftover = elemwise_ops(f, ignore_cast=False)
    assert not leftover, leftover
    out = f(vx, vy)
    assert out.shape == vy.shape, out.shape
    assert numpy.all(out == vy)
评论列表
文章目录