def _test_variable_sequence(self, cuda):
    def pad(var, length):
        # Zero-pad `var` along the time dimension so that it spans `length` steps.
        if var.size(0) == length:
            return var
        return torch.cat([var, Variable(var.data.new(length - var.size(0), *var.size()[1:]).zero_())])

    # Sequence lengths, sorted in decreasing order as required by pack_padded_sequence.
    lengths = [10, 10, 6, 2, 2, 1, 1]
    max_length = lengths[0]
    x_leaf = Variable(torch.randn(max_length, len(lengths), 3), requires_grad=True)
    lstm = nn.LSTM(3, 4, bidirectional=True, num_layers=2)
    # Identical copy of the LSTM, used for the per-sequence reference computation.
    lstm2 = deepcopy(lstm)
    if cuda:
        x = x_leaf.cuda()
        lstm.cuda()
        lstm2.cuda()
    else:
        x = x_leaf

    # Reference: run each sequence through lstm2 separately and re-pad the outputs.
    seq_outs = []
    seq_hiddens = []
    for i, l in enumerate(lengths):
        out, hid = lstm2(x[:l, i:i + 1])
        out_pad = pad(out, max_length)
        seq_outs.append(out_pad)
        seq_hiddens.append(hid)
    seq_out = torch.cat(seq_outs, 1)
    seq_hidden = tuple(torch.cat(hids, 1) for hids in zip(*seq_hiddens))

    # Same batch in packed format, run through the original LSTM in one call.
    packed = rnn_utils.pack_padded_sequence(x, lengths)
    packed_out, packed_hidden = lstm(packed)
    unpacked, unpacked_len = rnn_utils.pad_packed_sequence(packed_out)

    # Check forward: packed and per-sequence results must agree.
    self.assertEqual(packed_hidden, seq_hidden)
    self.assertEqual(unpacked, seq_out)
    self.assertEqual(unpacked_len, lengths)

    # Check backward: gradients w.r.t. the input must agree.
    seq_out.sum().backward()
    grad_x = x_leaf.grad.data.clone()
    x_leaf.grad.data.zero_()
    unpacked.sum().backward()
    self.assertEqual(x_leaf.grad.data, grad_x)

    # Gradients w.r.t. the LSTM weights must also match between the two copies.
    for p1, p2 in zip(lstm.parameters(), lstm2.parameters()):
        self.assertEqual(p1.grad, p2.grad)
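# ---------------------------------------------------------------------------
# For reference, a minimal standalone sketch of the same pack/unpack round trip
# outside the test harness. It assumes a recent PyTorch where plain tensors
# replace Variable, and it only uses calls that also appear in the test above
# (nn.LSTM, pack_padded_sequence, pad_packed_sequence); the shapes and lengths
# below are arbitrary illustration values.
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils

lengths = [5, 3, 2]                     # must be sorted in decreasing order
max_len, batch, feat = lengths[0], len(lengths), 3
x = torch.randn(max_len, batch, feat)   # padded input, shape (T, B, F)

lstm = nn.LSTM(feat, 4)                 # hidden size 4, single layer
packed = rnn_utils.pack_padded_sequence(x, lengths)
packed_out, (h_n, c_n) = lstm(packed)   # packed output plus final hidden/cell state
out, out_lengths = rnn_utils.pad_packed_sequence(packed_out)

print(out.shape)     # torch.Size([5, 3, 4])  -> (max_len, batch, hidden)
print(out_lengths)   # tensor([5, 3, 2])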