def ___test_infer_shape_tuple(self):
    """Check the output shape of ``choose`` when the choices are a tuple.

    The index array ``A`` has shape (2, 1, 1) and the two choice arrays
    have shapes (1, 4, 1) and (1, 1, 7); the compiled result is asserted
    to have shape (2, 4, 7).  The same graph is then passed to
    ``self._compile_and_check`` together with ``self.op_class``
    (presumably to compare the op's inferred shape with the computed
    one -- confirm against the tester base class).

    The method name does not start with ``test``, so default test
    discovery will not collect it; it appears to be disabled on purpose.
    """
    a = tensor.tensor3(dtype='int32')
    b = tensor.tensor3(dtype='int32')
    c = tensor.tensor3(dtype='int32')

    # Index values 1 and 0 select between the two choice arrays.
    A = numpy.asarray([1, 0], dtype='int32').reshape((2, 1, 1))
    # Draw non-degenerate integer choices.  Casting ``numpy.random.rand``
    # output (floats in [0, 1)) to int32 would truncate every entry to 0.
    B = numpy.random.randint(0, 10, size=(1, 4, 1)).astype('int32')
    C = numpy.random.randint(0, 10, size=(1, 1, 7)).astype('int32')

    f = function([a, b, c], choose(a, (b, c)))
    shape = (2, 4, 7)
    # Shapes are integer tuples, so compare them exactly; ``allclose``
    # would apply float tolerance and broadcasting, which make no sense
    # for a shape.
    assert f(A, B, C).shape == shape

    self._compile_and_check([a, b, c],  # theano.function inputs
                            [self.op(a, (b, c))],  # theano.function outputs
                            # Input data.  Always use non-square inputs.
                            [A, B, C],
                            # Op that should be removed from the graph.
                            self.op_class)
评论列表
文章目录