def reverse_sequences_torch(mini_batch, seq_lengths):
    """Reverse each sequence of a padded mini-batch along the time axis.

    For row ``b`` only the first ``seq_lengths[b]`` time steps are reversed.
    The reversed steps are written to the front of that row of the output.
    The remaining (padding) positions of the output are zero. The input
    tensor is not modified.

    Args:
        mini_batch: tensor indexed as ``(batch, time, ...)``; dim 0 is the
            batch axis and dim 1 is the time axis.
        seq_lengths: per-row valid lengths, indexable by batch position
            (list, numpy array, or 1-D tensor of integers).

    Returns:
        A new tensor with the same shape, dtype and device as ``mini_batch``.

    Raises:
        ValueError: if a length is negative or exceeds ``mini_batch.size(1)``.
    """
    # zeros_like preserves dtype and device (replaces the legacy pyro ng_zeros).
    reversed_mini_batch = torch.zeros_like(mini_batch)
    max_len = mini_batch.size(1)
    for b in range(mini_batch.size(0)):
        # int() accepts Python ints, numpy integers and 0-d tensors alike.
        T = int(seq_lengths[b])
        if not 0 <= T <= max_len:
            # Slicing would silently truncate an over-long length, so fail loudly.
            raise ValueError(
                "seq_lengths[{}]={} is outside [0, {}]".format(b, T, max_len)
            )
        # flip() runs on mini_batch's own device, so no separately allocated
        # index tensor can land on the wrong GPU.
        reversed_mini_batch[b, :T] = mini_batch[b, :T].flip(0)
    return reversed_mini_batch
# this function takes the hidden state as output by the PyTorch rnn and
# unpacks it; it also reverses each sequence temporally
# [scraped page label] Comment list
# [scraped page label] Table of contents