test_vm.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_partial_function_with_updates():

    def check_updates(linker_name):
        x = tensor.lscalar('input')
        y = theano.shared(numpy.asarray(1, 'int64'), name='global')
        f = theano.function([x], [x, x + 34], updates=[(y, x + 1)], mode=Mode(
            optimizer=None, linker=linker_name))
        g = theano.function([x], [x - 6], updates=[(y, y + 3)], mode=Mode(
            optimizer=None, linker=linker_name))

        assert f(3, output_subset=[]) == []
        assert y.get_value() == 4
        assert g(30, output_subset=[0]) == [24]
        assert g(40, output_subset=[]) == []
        assert y.get_value() == 10

    check_updates(vm.VM_Linker(allow_partial_eval=True, use_cloop=False))
    check_updates('cvm')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号