test_vm.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_c_thunks():
    a = tensor.scalars('a')
    b, c = tensor.vectors('bc')
    cases = [False]
    if theano.config.cxx:
        cases.append(True)
    for c_thunks in cases:
        f = function([a, b, c], ifelse(a, a * b, b * c),
                     mode=Mode(
                         optimizer=None,
                         linker=vm.VM_Linker(c_thunks=c_thunks,
                                             use_cloop=False)))
        f(1, [2], [3, 2])
        from nose.tools import assert_raises
        assert_raises(ValueError, f, 0, [2], [3, 4])
        assert any([hasattr(t, 'cthunk') for t in f.fn.thunks]) == c_thunks
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号