test_scan.py source code

python

Project: Theano-Deep-learning   Author: GeekLiB
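import numpy
import theano
from theano import tensor


# Method of a unittest test-case class: checks that a gradient can be taken
# through a scan whose recurrent state includes a float -> int32 cast.
def test_grad_dtype_change(self):
    x = tensor.fscalar('x')
    y = tensor.fscalar('y')
    c = tensor.iscalar('c')

    def inner_fn(cond, x, y):
        # The integer state is recomputed by casting a float value to int32.
        new_cond = tensor.cast(tensor.switch(cond, x, y), 'int32')
        new_x = tensor.switch(cond, tensor.nnet.sigmoid(y * x), x)
        new_y = tensor.switch(cond, y, tensor.nnet.sigmoid(x))
        return new_cond, new_x, new_y

    # Three recurrent states (cond, x, y), initialised from c, x and y.
    values, _ = theano.scan(
        inner_fn,
        outputs_info=[c, x, y],
        n_steps=10,
        truncate_gradient=-1,
        go_backwards=False)
    # Gradient of the summed x trajectory with respect to the initial x and y.
    gX, gY = tensor.grad(values[1].sum(), [x, y])
    f = theano.function([c, x, y], [gX, gY],
                        allow_input_downcast=True)
    # Only checks that compiling and evaluating the gradient raises no
    # runtime error; the returned gradient values are not asserted.
    f(numpy.int32(0), numpy.float32(1.), numpy.float32(.5))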
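To see what gradient this test should produce, the recurrence can be replayed in plain Python. The sketch below is only an illustration and is not part of Theano or the original test. The helper names `run_scan` and `numeric_grad` are hypothetical. The float-to-int32 cast is modelled with Python's `int()`, which truncates toward zero. The gradients are estimated with central finite differences rather than Theano's symbolic `tensor.grad`.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def inner_fn(cond, x, y):
    # Mirrors the Theano inner function; switch(cond, a, b) picks a when cond != 0.
    # int() stands in for the int32 cast (truncation toward zero) -- an assumption.
    new_cond = int(x if cond else y)
    new_x = sigmoid(y * x) if cond else x
    new_y = y if cond else sigmoid(x)
    return new_cond, new_x, new_y


def run_scan(c, x, y, n_steps=10):
    """Hypothetical helper: return the per-step x outputs (like values[1])."""
    xs = []
    for _ in range(n_steps):
        c, x, y = inner_fn(c, x, y)
        xs.append(x)
    return xs


def numeric_grad(c, x, y, eps=1e-6):
    """Hypothetical helper: central differences of sum(values[1]) w.r.t. x, y."""
    def loss(xv, yv):
        return sum(run_scan(c, xv, yv))

    g_x = (loss(x + eps, y) - loss(x - eps, y)) / (2 * eps)
    g_y = (loss(x, y + eps) - loss(x, y - eps)) / (2 * eps)
    return g_x, g_y


g_x, g_y = numeric_grad(0, 1.0, 0.5)
print(g_x, g_y)
```

With the test's inputs (c = 0, x = 1.0, y = 0.5) the condition stays 0 on every step (int(0.5) and int(sigmoid(1.0)) are both 0), so each of the 10 x outputs equals the initial x, and y never reaches them. The estimates therefore come out near 10 and exactly 0, which is what one would expect `f` to return if the symbolic gradient is correct; the original test itself does not check this.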