def test_python_ir(self):
    """Rebuild a traced graph by hand through the Python graph bindings.

    Traces ``sigmoid(tanh(x * (x + y)))``, copies the resulting graph ``g``
    into a freshly constructed ``torch._C.Graph`` (``g2``) one node at a
    time, appends an extra node carrying a tensor attribute, and finally
    checks ``str(g2)`` with ``self.assertExpected``.
    """
    x = Variable(torch.Tensor([0.4]), requires_grad=True)
    y = Variable(torch.Tensor([0.7]), requires_grad=True)

    def doit(x, y):
        return torch.sigmoid(torch.tanh(x * (x + y)))

    traced, _ = torch.jit.trace(doit, (x, y))
    g = torch._C._jit_get_graph(traced)
    g2 = torch._C.Graph()

    # Maps every input/node of ``g`` to its counterpart in ``g2`` so that
    # later nodes can be wired to the already-copied producers.
    g_to_g2 = {}
    for node in g.inputs():
        g_to_g2[node] = g2.addInput()

    for node in g.nodes():
        if node.kind() == "PythonOp":
            # Re-create the op under its Python name and attach a string
            # and an integer attribute through the chained setters.
            n_ = (
                g2.create(node.pyname(),
                          [g_to_g2[i] for i in node.inputs()])
                .setType(node.typeOption())
                .s_("note", "from_pyop")
                .i_("some_value", len(node.scalar_args()))
            )
            # The integer attribute must read back as it was written.
            assert n_.i("some_value") == len(node.scalar_args())
        else:
            # Clone the node, remapping its inputs into ``g2``.
            n_ = g2.createClone(node, lambda src: g_to_g2[src])
            # NOTE(review): this check is placed in the clone branch because
            # nodes built in the PythonOp branch are only given "note" and
            # "some_value"; presumably "i" denotes an integer attribute
            # kind -- confirm against the bindings.
            assert n_.kindOf("Offset") == "i"
        g_to_g2[node] = g2.appendNode(n_)

    for node in g.outputs():
        g2.registerOutput(g_to_g2[node])

    # Exercise the tensor-valued attribute setter/getter on a new node.
    t_node = g2.create("TensorTest").t_("a", torch.ones([2, 2]))
    assert t_node.attributeNames() == ["a"]
    g2.appendNode(t_node)
    assert torch.equal(torch.ones([2, 2]), t_node.t("a"))
    self.assertExpected(str(g2))
# Web-page navigation residue ("comment list", "article table of contents");
# not part of the test code.