def test_remove_useless_inputs2(self):
    """Check that the compiled IfElse node keeps only distinct branch pairs.

    ``ifelse`` is called with five (then, else) output pairs, of which only
    three are distinct: (x1, y1), (x1, y2) and (x2, y2).  The test expects
    the IfElse node in the compiled graph to end up with exactly 3 outputs.

    Currently skipped: the optimization under test is disabled.
    """
    # Skipped until the optimization is re-enabled; the code below is kept
    # (unreachable) so the check can be restored by deleting this line.
    raise SkipTest("Optimization temporarily disabled")
    x1 = tensor.vector('x1')
    x2 = tensor.vector('x2')
    y1 = tensor.vector('y1')
    y2 = tensor.vector('y2')
    c = tensor.iscalar('c')
    z = ifelse(c, (x1, x1, x1, x2, x2), (y1, y1, y2, y2, y2))
    f = theano.function([c, x1, x2, y1, y2], z)
    # Take the first IfElse node in topological order.  If there is none,
    # fail with a clear message rather than an IndexError.
    ifnode = next((node for node in f.maker.fgraph.toposort()
                   if isinstance(node.op, IfElse)), None)
    assert ifnode is not None, "compiled graph contains no IfElse node"
    assert len(ifnode.outputs) == 3, (
        "expected 3 IfElse outputs, got %d" % len(ifnode.outputs))
评论列表
文章目录