test_opt.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_recursive_lift(self):
        """Check that ``local_dimshuffle_lift`` pushes a transpose down to the leaves.

        The graph ``((v + 42) * (m + 84)).T`` initially has an outer
        ``InplaceDimShuffle{1,0}`` (the transpose) applied to an
        ``Elemwise{mul}``.  After a single call to the rewrite on that outer
        node, the expected graph has ``Elemwise{mul}`` at the top. The
        (composed) dimshuffles now sit directly on the vector input
        (``{0,x}``), the matrix input (``{1,0}``) and the constants (``{x,x}``).
        The test also checks that stack traces survive the rewrite.
        """
        v = T.vector(dtype="float64")
        m = T.matrix(dtype="float64")
        out = ((v + 42) * (m + 84)).T
        g = FunctionGraph([v, m], [out])
        init_str_g = ("[InplaceDimShuffle{1,0}(Elemwise{mul,no_inplace}"
                      "(InplaceDimShuffle{x,0}(Elemwise{add,no_inplace}"
                      "(<TensorType(float64, vector)>, "
                      "InplaceDimShuffle{x}(TensorConstant{42}))), "
                      "Elemwise{add,no_inplace}"
                      "(<TensorType(float64, matrix)>, "
                      "InplaceDimShuffle{x,x}(TensorConstant{84}))))]")
        # assertEqual (rather than assertTrue(a == b)) reports both strings
        # and a diff on failure, so a changed graph repr is easy to diagnose.
        self.assertEqual(str(g), init_str_g)
        # Apply the rewrite to the node that produces the output (the outer
        # transpose). Its result is indexed with [0]; presumably it is a list
        # of replacement outputs. TODO confirm against the optimizer API.
        new_out = local_dimshuffle_lift.transform(g.outputs[0].owner)[0]
        new_g = FunctionGraph(g.inputs, [new_out])
        opt_str_g = ("[Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}"
                     "(InplaceDimShuffle{0,x}(<TensorType(float64, vector)>), "
                     "InplaceDimShuffle{x,x}(TensorConstant{42})), "
                     "Elemwise{add,no_inplace}(InplaceDimShuffle{1,0}"
                     "(<TensorType(float64, matrix)>), "
                     "InplaceDimShuffle{x,x}(TensorConstant{84})))]")
        self.assertEqual(str(new_g), opt_str_g)
        # Check stacktrace was copied over correctly after opt was applied
        self.assertTrue(check_stack_trace(new_g, ops_to_check='all'))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号