test_jit.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:pytorch 作者: ezyang 项目源码 文件源码
def test_lstm_fusion(self):
        """Trace a hand-written LSTM cell, run JIT passes on it, and snapshot the graph.

        The local ``LSTMCell`` function is traced with CUDA tensors. The
        resulting graph then goes through the ONNX pass and the fusion pass,
        with a lint pass before, between and after them. Finally, the graph's
        string form is checked with ``assertExpected``.

        NOTE: the snapshot depends on the exact sequence of operations inside
        the local ``LSTMCell`` below. Reordering or substituting any op there
        changes ``str(trace)`` and requires regenerating the expected output.
        """
        # NOTE(review): all tensors are moved to the GPU unconditionally. No
        # CUDA-availability skip is visible here; confirm one exists on the
        # enclosing class or as a decorator outside this view.
        # Shapes (3, 10) and (3, 20) match nn.LSTMCell(10, 20) below;
        # presumably dim 0 is the batch dimension.
        input = Variable(torch.randn(3, 10).cuda())
        hx = Variable(torch.randn(3, 20).cuda())
        cx = Variable(torch.randn(3, 20).cuda())
        # Only used to allocate weights/biases of the correct sizes; its own
        # forward is never called in this test.
        module = nn.LSTMCell(10, 20).cuda()  # Just to allocate weights with correct sizes

        def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
            """Compute one LSTM cell step from functional ops (the code being traced).

            Args:
                input: input tensor for this step.
                hidden: pair ``(hx, cx)`` of previous hidden state and cell state.
                w_ih: weight applied to ``input`` via ``F.linear``.
                w_hh: weight applied to ``hx`` via ``F.linear``.
                b_ih: optional bias for the ``input`` projection (default None).
                b_hh: optional bias for the ``hx`` projection (default None).

            Returns:
                Tuple ``(hy, cy)``: new hidden state and new cell state.
            """
            hx, cx = hidden
            # Gate pre-activations: sum of the input and hidden projections.
            gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)

            # Split the pre-activations into four chunks along dim 1.
            ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
            ingate = F.sigmoid(ingate)
            forgetgate = F.sigmoid(forgetgate)
            cellgate = F.tanh(cellgate)
            outgate = F.sigmoid(outgate)

            # New cell state: forget-gated old cell plus input-gated candidate.
            cy = (forgetgate * cx) + (ingate * cellgate)
            # New hidden state: output gate applied to tanh of the new cell state.
            hy = outgate * F.tanh(cy)
            return hy, cy

        # Arguments are flattened as (input, (hx, cx), *module.parameters()).
        # Assumes the parameters come out in the order (w_ih, w_hh, b_ih, b_hh)
        # to match LSTMCell's signature; this relies on nn.LSTMCell's
        # registration order (confirm).
        # This (legacy) torch.jit.trace returns a pair; only the first element,
        # the traced graph, is used.
        trace, _ = torch.jit.trace(LSTMCell, (input, (hx, cx)) + tuple(module.parameters()))
        # Lint before/after each transformation; presumably checks graph invariants.
        torch._C._jit_pass_lint(trace)
        # Presumably lowers the traced ops to their ONNX-style form.
        torch._C._jit_pass_onnx(trace)
        torch._C._jit_pass_lint(trace)
        # Presumably groups fusible elementwise ops into fusion nodes (needs CUDA).
        torch._C._jit_pass_fuse(trace)
        torch._C._jit_pass_lint(trace)
        # Snapshot check; presumably compares against a stored expected-output
        # file keyed on this test's name.
        self.assertExpected(str(trace))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号