test_ifelse.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_multiple_out_grad(self):
        """Check gradients through a lazy ``ifelse`` with two outputs.

        Builds ``ifelse(c, (x1, x2), (y1, y2))``, differentiates the sum of
        both outputs with respect to all four input vectors, and verifies
        for both values of the condition that:

        * every gradient has the same length as its corresponding input;
        * gradients w.r.t. the selected branch's inputs are all ones;
        * gradients w.r.t. the other branch's inputs are all zeros.
        """
        x1 = tensor.vector('x1')
        x2 = tensor.vector('x2')
        y1 = tensor.vector('y1')
        y2 = tensor.vector('y2')
        c = tensor.iscalar('c')
        z = ifelse(c, (x1, x2), (y1, y2))
        grads = tensor.grad(z[0].sum() + z[1].sum(),
                            [x1, x2, y1, y2])

        f = theano.function([c, x1, x2, y1, y2], grads)
        rng = numpy.random.RandomState(utt.fetch_seed())

        # Draw all four lengths before any values so the sequence of RNG
        # calls (and therefore the seeded data) stays reproducible.
        lens = [rng.randint(200) for _ in range(4)]
        values = [numpy.asarray(rng.uniform(size=(n,)), theano.config.floatX)
                  for n in lens]

        # Condition true: the (x1, x2) branch is selected.
        outs_1 = f(1, *values)
        assert all(g.shape[0] == n for g, n in zip(outs_1, lens))
        assert numpy.all(outs_1[0] == 1.)
        assert numpy.all(outs_1[1] == 1.)
        assert numpy.all(outs_1[2] == 0.)
        assert numpy.all(outs_1[3] == 0.)

        # Condition false: the (y1, y2) branch is selected.
        outs_0 = f(0, *values)
        # Shapes must be checked on outs_0 here; checking outs_1 again would
        # leave the false-branch gradient shapes unverified.
        assert all(g.shape[0] == n for g, n in zip(outs_0, lens))
        assert numpy.all(outs_0[0] == 0.)
        assert numpy.all(outs_0[1] == 0.)
        assert numpy.all(outs_0[2] == 1.)
        assert numpy.all(outs_0[3] == 1.)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号