test_jit.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:pytorch 作者: ezyang 项目源码 文件源码
def test_python_ir(self):
    """Copy a traced graph into a fresh graph through the Python IR bindings.

    Traces a small sigmoid/tanh expression, then rebuilds the resulting
    graph piece by piece (inputs, nodes, outputs) inside a new
    ``torch._C.Graph``. Attribute setters and getters are checked along the
    way, and the string form of the rebuilt graph is finally handed to
    ``self.assertExpected``.
    """
    x = Variable(torch.Tensor([0.4]), requires_grad=True)
    y = Variable(torch.Tensor([0.7]), requires_grad=True)

    def doit(x, y):
        return torch.sigmoid(torch.tanh(x * (x + y)))

    traced, _ = torch.jit.trace(doit, (x, y))
    src_graph = torch._C._jit_get_graph(traced)
    dst_graph = torch._C.Graph()

    # Maps every value/node of the traced graph to its counterpart in the
    # graph being built, so later nodes can look up their rewired inputs.
    remap = {}

    for src_input in src_graph.inputs():
        remap[src_input] = dst_graph.addInput()

    for src_node in src_graph.nodes():
        if node_is_pyop := (src_node.kind() == "PythonOp"):
            # PythonOp nodes are not cloned: a new node named after the
            # Python op is created and tagged with two extra attributes.
            op_name = src_node.pyname()
            mapped_inputs = [remap[inp] for inp in src_node.inputs()]
            copied = (
                dst_graph.create(op_name, mapped_inputs)
                .setType(src_node.typeOption())
                .s_("note", "from_pyop")
                .i_("some_value", len(src_node.scalar_args()))
            )
            # The integer attribute must read back as it was written.
            assert copied.i("some_value") == len(src_node.scalar_args())
        else:
            # Everything else is cloned, with inputs redirected via remap.
            copied = dst_graph.createClone(src_node, lambda old: remap[old])
            # The clone is expected to report kind "i" for its "Offset"
            # attribute.
            assert copied.kindOf("Offset") == "i"

        remap[src_node] = dst_graph.appendNode(copied)

    for src_output in src_graph.outputs():
        dst_graph.registerOutput(remap[src_output])

    # Exercise tensor-valued attributes on a standalone node.
    tensor_node = dst_graph.create("TensorTest").t_("a", torch.ones([2, 2]))
    assert tensor_node.attributeNames() == ["a"]
    dst_graph.appendNode(tensor_node)
    assert torch.equal(torch.ones([2, 2]), tensor_node.t("a"))

    self.assertExpected(str(dst_graph))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号