test_nn.py source code

python

Project: pytorch    Author: pytorch
# Module-level imports assumed by this excerpt (the method belongs to a test class in test_nn.py)
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
from torch.autograd import Variable

def _test_variable_sequence(self, cuda):
        # Zero-pad `var` along the time dimension so every per-sequence output has max_length steps
        def pad(var, length):
            if var.size(0) == length:
                return var
            return torch.cat([var, Variable(var.data.new(length - var.size(0), *var.size()[1:]).zero_())])

        # Batch of 7 sequences; lengths must be in decreasing order for pack_padded_sequence
        lengths = [10, 10, 6, 2, 2, 1, 1]
        max_length = lengths[0]
        x_leaf = Variable(torch.randn(max_length, len(lengths), 3), requires_grad=True)
        lstm = nn.LSTM(3, 4, bidirectional=True, num_layers=2)
        lstm2 = deepcopy(lstm)  # identical weights, used for the per-sequence reference path
        if cuda:
            # .cuda() on a Variable is differentiable, so gradients still flow back to x_leaf
            x = x_leaf.cuda()
            lstm.cuda()
            lstm2.cuda()
        else:
            x = x_leaf

        # Compute sequences separately
        seq_outs = []
        seq_hiddens = []
        for i, l in enumerate(lengths):
            out, hid = lstm2(x[:l, i:i + 1])
            out_pad = pad(out, max_length)
            seq_outs.append(out_pad)
            seq_hiddens.append(hid)
        seq_out = torch.cat(seq_outs, 1)
        # Concatenate the per-sequence (h_n, c_n) states along the batch dimension
        seq_hidden = tuple(torch.cat(hids, 1) for hids in zip(*seq_hiddens))

        # Use packed format
        packed = rnn_utils.pack_padded_sequence(x, lengths)
        packed_out, packed_hidden = lstm(packed)
        unpacked, unpacked_len = rnn_utils.pad_packed_sequence(packed_out)

        # Check forward
        self.assertEqual(packed_hidden, seq_hidden)
        self.assertEqual(unpacked, seq_out)
        self.assertEqual(unpacked_len, lengths)

        # Check backward
        seq_out.sum().backward()
        grad_x = x_leaf.grad.data.clone()
        x_leaf.grad.data.zero_()
        unpacked.sum().backward()

        self.assertEqual(x_leaf.grad.data, grad_x)
        for p1, p2 in zip(lstm.parameters(), lstm2.parameters()):
            self.assertEqual(p1.grad, p2.grad)
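
For context, the packing round trip that the test compares against the per-sequence loop reduces to a few calls. The sketch below is independent of the test class; the names and sizes (lengths, padded, lstm) are illustrative and not taken from test_nn.py.

import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
from torch.autograd import Variable

# Three sequences of lengths 5, 3 and 2 in a padded (seq_len, batch, input_size) tensor.
# pack_padded_sequence expects the lengths in decreasing order; the values sitting in the
# padded positions are dropped once the batch is packed, so they do not matter here.
lengths = [5, 3, 2]
padded = Variable(torch.randn(5, 3, 8))
lstm = nn.LSTM(input_size=8, hidden_size=16)

packed = rnn_utils.pack_padded_sequence(padded, lengths)
packed_out, (h_n, c_n) = lstm(packed)

# Recover a padded (seq_len, batch, hidden_size) tensor plus the original lengths.
unpacked, unpacked_lengths = rnn_utils.pad_packed_sequence(packed_out)
print(unpacked.size(), unpacked_lengths)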